diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt new file mode 100644 index 00000000..69a0ec9d --- /dev/null +++ b/plugins/CMakeLists.txt @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +# WASI plug-in: WASI-Crypto proposal. +if(WASMEDGE_PLUGIN_WASI_CRYPTO) + add_subdirectory(wasi_crypto) +endif() + +# WASI plug-in: WASI-Http proposal. +if(WASMEDGE_PLUGIN_WASI_HTTP) + add_subdirectory(wasi_http) +endif() + +# WASI plug-in: WASI-Logging proposal. +if(WASMEDGE_PLUGIN_WASI_LOGGING) + # BUILTIN-PLUGIN: Add the wasi-logging plugin here after the new plugin + # architecture is ready in 0.15.0. +endif() + +# WASI plug-in: WASI-NN proposal with backends. +if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) + add_subdirectory(wasi_nn) +endif() + +# WASI plug-in: WASI-Poll proposal. +if(WASMEDGE_PLUGIN_WASI_POLL) + add_subdirectory(wasi_poll) +endif() + +# WasmEdge plug-in: wasm-bpf. +if(WASMEDGE_PLUGIN_WASM_BPF) + # wasm_bpf is currently supported only on Linux systems. + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasm_bpf) + else() + message(WARNING "Only Linux platforms support wasm_bpf plug-in now.") + endif() +endif() + +# WasmEdge plug-in: ffmpeg. +if(WASMEDGE_PLUGIN_FFMPEG) + add_subdirectory(wasmedge_ffmpeg) +endif() + +# WasmEdge plug-in: Image. +if(WASMEDGE_PLUGIN_IMAGE) + # wasmedge_image is currently supported only on Linux and macOS. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_image) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_Image plug-in now.") + endif() +endif() + +# WasmEdge plug-in: LLMC. +if(WASMEDGE_PLUGIN_LLMC) + add_subdirectory(wasmedge_llmc) +endif() + +# WasmEdge plug-in: OCR. +if(WASMEDGE_PLUGIN_OCR) + add_subdirectory(wasmedge_ocr) +endif() + +# WasmEdge plug-in: OpenCV-mini. +if(WASMEDGE_PLUGIN_OPENCVMINI) + # wasmedge_opencvmini is currently supported only on Linux and macOS. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_opencvmini) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_OpenCVMini plug-in now.") + endif() +endif() + +# WasmEdge plug-in: Process. +if(WASMEDGE_PLUGIN_PROCESS) + # wasmedge_process is currently supported only on Linux systems. + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasmedge_process) + else() + message(WARNING "Only Linux platforms support WasmEdge_Process plug-in now.") + endif() +endif() + +# WasmEdge plug-in: Stable-diffusion. +if(WASMEDGE_PLUGIN_STABLEDIFFUSION) + # wasmedge_stablediffusion is currently supported only on Linux and macOS. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_stablediffusion) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_StableDiffusion plug-in now.") + endif() +endif() + +# WasmEdge plug-in: TensorFlow. +if(WASMEDGE_PLUGIN_TENSORFLOW) + # wasmedge_tensorflow is currently supported only on Linux and macOS. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflow) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_Tensorflow plug-in now.") + endif() +endif() + +# WasmEdge plug-in: TensorFlow-Lite. +if(WASMEDGE_PLUGIN_TENSORFLOWLITE) + # wasmedge_tensorflowlite is currently supported only on Linux and macOS. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflowlite) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_TensorflowLite plug-in now.") + endif() +endif() + +# WasmEdge plug-in: zlib. +if(WASMEDGE_PLUGIN_ZLIB) + add_subdirectory(wasmedge_zlib) +endif() diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt new file mode 100644 index 00000000..48ae00c9 --- /dev/null +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +set(OPENSSL_USE_STATIC_LIBS ON) +find_package(OpenSSL REQUIRED) + +wasmedge_add_library(wasmedgePluginWasiCrypto + SHARED + ctx.cpp + asymmetric_common/ctx.cpp + asymmetric_common/func.cpp + asymmetric_common/keypair.cpp + asymmetric_common/module.cpp + asymmetric_common/publickey.cpp + asymmetric_common/secretkey.cpp + common/array_output.cpp + common/ctx.cpp + common/func.cpp + common/module.cpp + common/options.cpp + kx/ctx.cpp + kx/dh/ecdsa.cpp + kx/dh/x25519.cpp + kx/func.cpp + kx/kx.cpp + kx/module.cpp + kx/options.cpp + signatures/ctx.cpp + signatures/ecdsa.cpp + signatures/eddsa.cpp + signatures/func.cpp + signatures/module.cpp + signatures/options.cpp + signatures/rsa.cpp + signatures/signatures.cpp + signatures/signstate.cpp + signatures/verificationstate.cpp + symmetric/aeads.cpp + symmetric/ctx.cpp + symmetric/func.cpp + symmetric/hash.cpp + symmetric/kdf.cpp + symmetric/key.cpp + symmetric/mac.cpp + symmetric/module.cpp + symmetric/options.cpp + symmetric/state.cpp + symmetric/tag.cpp + utils/evp_wrapper.cpp + utils/hostfunction.cpp +) + +target_compile_options(wasmedgePluginWasiCrypto + PUBLIC + -DWASMEDGE_PLUGIN + -DOPENSSL_API_COMPAT=0x10100000L +) + +target_include_directories(wasmedgePluginWasiCrypto + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/thirdparty +) + +target_link_libraries(wasmedgePluginWasiCrypto + PUBLIC + OpenSSL::Crypto +) +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasiCrypto + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiCrypto + PRIVATE + wasmedge_shared + ) +endif() + +install( + TARGETS wasmedgePluginWasiCrypto + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasi_crypto/asymmetric_common/ctx.cpp b/plugins/wasi_crypto/asymmetric_common/ctx.cpp new file mode 100644 index 00000000..a610a654 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/ctx.cpp @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ctx.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +WasiCryptoExpect<__wasi_array_output_t> +Context::publickeyExport(__wasi_publickey_t PkHandle, + __wasi_publickey_encoding_e_t Encoding) noexcept { + return PublicKeyManager.get(PkHandle) + .and_then([Encoding](auto &&Pk) { + return AsymmetricCommon::pkExportData(std::forward(Pk), + Encoding); + }) + .and_then([this](auto &&Data) { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect +Context::publickeyVerify(__wasi_publickey_t PkHandle) noexcept { + return PublicKeyManager.get(PkHandle).and_then(AsymmetricCommon::pkVerify); +} + +WasiCryptoExpect +Context::publickeyClose(__wasi_publickey_t PkHandle) noexcept { + return PublicKeyManager.close(PkHandle); +} + +WasiCryptoExpect<__wasi_array_output_t> +Context::secretkeyExport(__wasi_secretkey_t SkHandle, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + return SecretKeyManager.get(SkHandle) + .and_then([Encoding](auto &&Sk) { + return AsymmetricCommon::skExportData(std::forward(Sk), + Encoding); + }) + .and_then([this](auto &&Data) noexcept { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect +Context::secretkeyClose(__wasi_secretkey_t SkHandle) noexcept { + return SecretKeyManager.close(SkHandle); +} + +WasiCryptoExpect<__wasi_publickey_t> +Context::publickeyFromSecretkey(__wasi_secretkey_t SkHandle) noexcept { + return SecretKeyManager.get(SkHandle) + .and_then(AsymmetricCommon::skPublicKey) + .and_then([this](auto &&Pk) noexcept { + return PublicKeyManager.registerManager(std::forward(Pk)); + }); +} + +WasiCryptoExpect<__wasi_array_output_t> +Context::keypairExport(__wasi_keypair_t KpHandle, + __wasi_keypair_encoding_e_t Encoding) noexcept { + return KeyPairManager.get(KpHandle) + .and_then([Encoding](auto &&Kp) noexcept { + return AsymmetricCommon::kpExportData(std::forward(Kp), + Encoding); + }) + .and_then([this](auto &&Data) noexcept { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect<__wasi_publickey_t> +Context::keypairPublickey(__wasi_keypair_t KpHandle) noexcept { + return KeyPairManager.get(KpHandle) + .and_then(AsymmetricCommon::kpPublicKey) + .and_then([this](auto &&Pk) noexcept { + return PublicKeyManager.registerManager(std::forward(Pk)); + }); +} + +WasiCryptoExpect<__wasi_secretkey_t> +Context::keypairSecretkey(__wasi_keypair_t KpHandle) noexcept { + return KeyPairManager.get(KpHandle) + .and_then(AsymmetricCommon::kpSecretKey) + .and_then([this](auto &&Sk) noexcept { + return SecretKeyManager.registerManager(std::forward(Sk)); + }); +} + +WasiCryptoExpect +Context::keypairClose(__wasi_keypair_t KpHandle) noexcept { + return KeyPairManager.close(KpHandle); +} + +WasiCryptoExpect<__wasi_keypair_t> +Context::keypairFromPkAndSk(__wasi_publickey_t PkHandle, + __wasi_secretkey_t SkHandle) noexcept { + auto Pk = PublicKeyManager.get(PkHandle); + if (!Pk) { + return WasiCryptoUnexpect(Pk); + } + + auto Sk = SecretKeyManager.get(SkHandle); + if (!Sk) { + return WasiCryptoUnexpect(Sk); + } + + return AsymmetricCommon::kpFromPkAndSk(*Pk, *Sk).and_then( + [this](auto &&Kp) noexcept { + return KeyPairManager.registerManager(std::forward(Kp)); + }); +} + +WasiCryptoExpect<__wasi_keypair_t> +Context::keypairGenerate(AsymmetricCommon::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept { + return mapAndTransposeOptional( + OptOptionsHandle, + [this](__wasi_options_t OptionsHandle) noexcept { + return OptionsManager.get(OptionsHandle); + }) + .and_then([Alg](auto &&OptOptions) noexcept { + return AsymmetricCommon::generateKp( + Alg, asOptionalRef(std::forward(OptOptions))); + }) + .and_then([this](auto &&Kp) noexcept { + return KeyPairManager.registerManager(std::forward(Kp)); + }); +} + +WasiCryptoExpect<__wasi_keypair_t> +Context::keypairImport(AsymmetricCommon::Algorithm Alg, + Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + return AsymmetricCommon::importKp(Alg, Encoded, Encoding) + .and_then([this](auto &&Kp) noexcept { + return KeyPairManager.registerManager(std::forward(Kp)); + }); +} + +WasiCryptoExpect<__wasi_publickey_t> +Context::publickeyImport(AsymmetricCommon::Algorithm Alg, + Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + return AsymmetricCommon::importPk(Alg, Encoded, Encoding) + .and_then([this](auto &&Pk) noexcept { + return PublicKeyManager.registerManager(std::forward(Pk)); + }); +} + +WasiCryptoExpect<__wasi_secretkey_t> +Context::secretkeyImport(AsymmetricCommon::Algorithm Alg, + Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + return AsymmetricCommon::importSk(Alg, Encoded, Encoding) + .and_then([this](auto &&Sk) noexcept { + return SecretKeyManager.registerManager(std::forward(Sk)); + }); +} + +WasiCryptoExpect<__wasi_keypair_t> Context::keypairGenerateManaged( + __wasi_secrets_manager_t, AsymmetricCommon::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept { + if (std::holds_alternative(Alg)) { + return keypairGenerate(Alg, OptOptionsHandle); + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect Context::keypairStoreManaged(__wasi_secrets_manager_t, + __wasi_keypair_t, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect<__wasi_version_t> +Context::keypairReplaceManaged(__wasi_secrets_manager_t, __wasi_keypair_t, + __wasi_keypair_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect> +Context::keypairId(__wasi_keypair_t, Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect<__wasi_keypair_t> +Context::keypairFromId(__wasi_secrets_manager_t, Span, + __wasi_version_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/ecdsa.h b/plugins/wasi_crypto/asymmetric_common/ecdsa.h new file mode 100644 index 00000000..85990ca3 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/ecdsa.h @@ -0,0 +1,408 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/ecdsa.h - Ecdsa alg-===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of the ECDSA algorithm. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +template +class Ecdsa { + inline static const size_t UnCompressedPkSize = 65; + inline static const size_t CompressedPkSize = 33; + constexpr static size_t getRawPkSize(bool Compressed) { + return Compressed ? CompressedPkSize : UnCompressedPkSize; + } + + inline static const size_t SkSize = 32; + + constexpr static point_conversion_form_t getForm(bool Compressed) noexcept { + return Compressed ? POINT_CONVERSION_COMPRESSED + : POINT_CONVERSION_UNCOMPRESSED; + } + +public: + class PublicKeyBase { + public: + PublicKeyBase(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + PublicKeyBase(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_PKCS8: + return importPkcs8(Encoded); + case __WASI_PUBLICKEY_ENCODING_PEM: + return importPem(Encoded); + case __WASI_PUBLICKEY_ENCODING_SEC: + return importSec(Encoded); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + WasiCryptoExpect> + exportData(__wasi_publickey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_SEC: + return exportSec(false); + case __WASI_PUBLICKEY_ENCODING_PEM: + return exportPem(false); + case __WASI_PUBLICKEY_ENCODING_PKCS8: + return exportPkcs8(false); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + WasiCryptoExpect verify() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + } + + protected: + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPUBKEY(Encoded)}); + } + + static WasiCryptoExpect + importPem(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPUBKEY(Encoded)}); + } + + static WasiCryptoExpect + importSec(Span Encoded) noexcept { + EcKeyPtr EcCtx{EC_KEY_new_by_curve_name(CurveNid)}; + EcPointPtr Pk{EC_POINT_new(EC_KEY_get0_group(EcCtx.get()))}; + ensureOrReturn(EC_POINT_oct2point(EC_KEY_get0_group(EcCtx.get()), + Pk.get(), Encoded.data(), + Encoded.size(), nullptr), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + opensslCheck(EC_KEY_set_public_key(EcCtx.get(), Pk.get())); + + EvpPkeyPtr Ctx{EVP_PKEY_new()}; + opensslCheck(EVP_PKEY_set1_EC_KEY(Ctx.get(), EcCtx.get())); + + return checkValid(std::move(Ctx)); + } + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + const EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + const EC_GROUP *Group = EC_KEY_get0_group(EcCtx); + ensureOrReturn(Group, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(EC_GROUP_get_curve_name(Group) == CurveNid, + __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; + } + + WasiCryptoExpect> + exportSec(bool Compressed) const noexcept { + const EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + std::vector Res(getRawPkSize(Compressed)); + opensslCheck(EC_POINT_point2oct( + EC_KEY_get0_group(EcCtx), EC_KEY_get0_public_key(EcCtx), + getForm(Compressed), Res.data(), Res.size(), nullptr)); + return Res; + } + + WasiCryptoExpect> + exportPem(bool Compressed) const noexcept { + EC_KEY_set_conv_form( + const_cast(EVP_PKEY_get0_EC_KEY(Ctx.get())), + getForm(Compressed)); + + return pemWritePUBKEY(Ctx.get()); + } + + WasiCryptoExpect> + exportPkcs8(bool Compressed) const noexcept { + EC_KEY_set_conv_form( + const_cast(EVP_PKEY_get0_EC_KEY(Ctx.get())), + getForm(Compressed)); + + return i2dPUBKEY(Ctx.get()); + } + + SharedEvpPkey Ctx; + }; + + class SecretKeyBase { + public: + SecretKeyBase(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + SecretKeyBase(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: + return importRaw(Encoded); + case __WASI_SECRETKEY_ENCODING_PKCS8: + return importPkcs8(Encoded); + case __WASI_SECRETKEY_ENCODING_PEM: + return importPem(Encoded); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + WasiCryptoExpect + exportData(__wasi_secretkey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: + return exportRaw(); + case __WASI_SECRETKEY_ENCODING_PKCS8: + return exportPkcs8(); + case __WASI_SECRETKEY_ENCODING_PEM: + return exportPem(); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + WasiCryptoExpect publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; + } + + WasiCryptoExpect toKeyPair(const PublicKey &) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + } + + protected: + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}); + } + + static WasiCryptoExpect + importPem(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPrivateKey(Encoded)}); + } + + static WasiCryptoExpect + importRaw(Span Encoded) noexcept { + EcKeyPtr EcCtx{EC_KEY_new_by_curve_name(CurveNid)}; + ensureOrReturn(Encoded.size() == SkSize, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + BnPtr Sk{ + BN_bin2bn(Encoded.data(), static_cast(Encoded.size()), nullptr)}; + ensureOrReturn(EC_KEY_set_private_key(EcCtx.get(), Sk.get()), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + + EvpPkeyPtr Ctx{EVP_PKEY_new()}; + opensslCheck(EVP_PKEY_set1_EC_KEY(Ctx.get(), EcCtx.get())); + + return Ctx; + } + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + const EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + const EC_GROUP *Group = EC_KEY_get0_group(EcCtx); + ensureOrReturn(Group, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(EC_GROUP_get_curve_name(Group) == CurveNid, + __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; + } + + WasiCryptoExpect exportPkcs8() const noexcept { + EVP_PKEY *Key = Ctx.get(); + BioPtr Bio{BIO_new(BIO_s_mem())}; + opensslCheck(i2d_PKCS8PrivateKey_bio(Bio.get(), Key, nullptr, nullptr, 0, + nullptr, nullptr)); + + BUF_MEM *Mem = nullptr; + opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); + SecretVec Ret(Mem->length); + + if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { + ensureOrReturn(Size == Ret.size(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + return Ret; + } + + WasiCryptoExpect exportPem() const noexcept { + return pemWritePrivateKey(Ctx.get()); + } + + WasiCryptoExpect exportRaw() const noexcept { + // Must equal to SkSize, not check. + const BIGNUM *Sk = + EC_KEY_get0_private_key(EVP_PKEY_get0_EC_KEY(Ctx.get())); + SecretVec Res(SkSize); + opensslCheck(BN_bn2bin(Sk, Res.data())); + + return Res; + } + + SharedEvpPkey Ctx; + }; + + class KeyPairBase { + public: + KeyPairBase(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + KeyPairBase(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + generate(OptionalRef) noexcept { + EvpPkeyCtxPtr ParamCtx{EVP_PKEY_CTX_new_id(EVP_PKEY_EC, nullptr)}; + EVP_PKEY_keygen_init(ParamCtx.get()); + EVP_PKEY_CTX_set_ec_paramgen_curve_nid(ParamCtx.get(), CurveNid); + + EVP_PKEY *Key = nullptr; + opensslCheck(EVP_PKEY_keygen(ParamCtx.get(), &Key)); + + return EvpPkeyPtr{Key}; + } + + static WasiCryptoExpect + import(Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: + return importRaw(Encoded); + case __WASI_KEYPAIR_ENCODING_PKCS8: + return importPkcs8(Encoded); + case __WASI_KEYPAIR_ENCODING_PEM: + return importPem(Encoded); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + WasiCryptoExpect publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; + } + + WasiCryptoExpect secretKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; + } + + WasiCryptoExpect + exportData(__wasi_keypair_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: + return exportRaw(); + case __WASI_KEYPAIR_ENCODING_PKCS8: + return exportPkcs8(); + case __WASI_KEYPAIR_ENCODING_PEM: + return exportPem(); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + protected: + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}); + } + + static WasiCryptoExpect + importPem(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPrivateKey(Encoded)}); + } + + static WasiCryptoExpect + importRaw(Span Encoded) noexcept { + ensureOrReturn(Encoded.size() == SkSize, __WASI_CRYPTO_ERRNO_INVALID_KEY); + EcKeyPtr EcCtx{EC_KEY_new_by_curve_name(CurveNid)}; + BnPtr Sk{ + BN_bin2bn(Encoded.data(), static_cast(Encoded.size()), nullptr)}; + ensureOrReturn(EC_KEY_set_private_key(EcCtx.get(), Sk.get()), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + + // Calculate and set Pk. + EcPointPtr Pk{EC_POINT_new(EC_KEY_get0_group(EcCtx.get()))}; + opensslCheck(EC_POINT_mul(EC_KEY_get0_group(EcCtx.get()), Pk.get(), + Sk.get(), nullptr, nullptr, nullptr)); + opensslCheck(EC_KEY_set_public_key(EcCtx.get(), Pk.get())); + + EvpPkeyPtr Ctx{EVP_PKEY_new()}; + opensslCheck(EVP_PKEY_set1_EC_KEY(Ctx.get(), EcCtx.get())); + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + return Ctx; + } + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + const EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + // Curve id check. + const EC_GROUP *Group = EC_KEY_get0_group(EcCtx); + ensureOrReturn(Group, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(EC_GROUP_get_curve_name(Group) == CurveNid, + __WASI_CRYPTO_ERRNO_INVALID_KEY); + // Have public key. + ensureOrReturn(EC_KEY_get0_public_key(EcCtx), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; + } + + WasiCryptoExpect exportPkcs8() const noexcept { + return i2dPrivateKey(Ctx.get()); + } + + WasiCryptoExpect exportPem() const noexcept { + return pemWritePrivateKey(Ctx.get()); + } + + WasiCryptoExpect exportRaw() const noexcept { + const BIGNUM *Sk = + EC_KEY_get0_private_key(EVP_PKEY_get0_EC_KEY(Ctx.get())); + SecretVec Res(SkSize); + opensslCheck(BN_bn2bin(Sk, Res.data())); + + return Res; + } + + SharedEvpPkey Ctx; + }; +}; + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/func.cpp b/plugins/wasi_crypto/asymmetric_common/func.cpp new file mode 100644 index 00000000..698f72a6 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/func.cpp @@ -0,0 +1,517 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "asymmetric_common/func.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +Expect KeypairGenerate::body(const Runtime::CallingFrame &Frame, + uint32_t AlgType, uint32_t AlgPtr, + uint32_t AlgLen, + uint32_t OptOptionsHandlePtr, + uint32_t /* Out */ KpHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + + AsymmetricCommon::Algorithm WasiAlg; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( + [Alg](auto WasiAlgType) { return tryFrom(WasiAlgType, Alg); }); + unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const OptOptionsHandle = + MemInst->getPointer(OptOptionsHandlePtr); + checkExist(OptOptionsHandle); + + auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); + checkExist(KpHandle); + + if (auto Res = Ctx.keypairGenerate(WasiAlg, *OptOptionsHandle); + unlikely(!Res)) { + return Res.error(); + } else { + *KpHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairImport::body(const Runtime::CallingFrame &Frame, + uint32_t AlgType, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ KpHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + + AsymmetricCommon::Algorithm WasiAlg; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( + [Alg](auto WasiAlgType) { return tryFrom(WasiAlgType, Alg); }); + unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + const __wasi_size_t WasiEncodedLen = EncodedLen; + const auto Encoded = + MemInst->getSpan(EncodedPtr, WasiEncodedLen); + checkRangeExist(Encoded, WasiEncodedLen); + + const auto WasiEncoding = cast<__wasi_keypair_encoding_e_t>(Encoding); + checkExist(WasiEncoding); + + auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); + checkExist(KpHandle); + + if (auto Res = Ctx.keypairImport(WasiAlg, Encoded, *WasiEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *KpHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairGenerateManaged::body( + const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, + uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsHandlePtr, uint32_t KpHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + + AsymmetricCommon::Algorithm WasiAlg; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( + [Alg](auto WasiAlgType) { return tryFrom(WasiAlgType, Alg); }); + unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const OptOptionsHandle = + MemInst->getPointer(OptOptionsHandlePtr); + checkExist(OptOptionsHandle); + + auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); + checkExist(KpHandle); + + if (auto Res = Ctx.keypairGenerateManaged(SecretsManagerHandle, WasiAlg, + *OptOptionsHandle); + unlikely(!Res)) { + return Res.error(); + } else { + *KpHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairStoreManaged::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + int32_t KpHandle, uint32_t KpIdPtr, + uint32_t KpIdMaxLen) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiKpIdMaxLen = KpIdMaxLen; + const auto KpId = MemInst->getSpan(KpIdPtr, WasiKpIdMaxLen); + checkRangeExist(KpId, WasiKpIdMaxLen); + + if (auto Res = Ctx.keypairStoreManaged(SecretsManagerHandle, KpHandle, KpId); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairReplaceManaged::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + int32_t OldKpHandle, + int32_t NewKpHandle, + uint32_t /* Out */ KpVersionPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const KpVersion = MemInst->getPointer<__wasi_version_t *>(KpVersionPtr); + checkExist(KpVersion); + + if (auto Res = Ctx.keypairReplaceManaged(SecretsManagerHandle, OldKpHandle, + NewKpHandle); + unlikely(!Res)) { + return Res.error(); + } else { + *KpVersion = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairId::body(const Runtime::CallingFrame &Frame, + int32_t KpHandle, uint32_t KpIdPtr, + uint32_t KpIdMaxLen, + uint32_t /* Out */ SizePtr, + uint32_t /* Out */ KpVersionPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiKpIdMaxLen = KpIdMaxLen; + const auto KpId = MemInst->getSpan(KpIdPtr, WasiKpIdMaxLen); + checkRangeExist(KpId, WasiKpIdMaxLen); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + auto *const Version = MemInst->getPointer<__wasi_version_t *>(KpVersionPtr); + checkExist(Version); + + if (auto Res = Ctx.keypairId(KpHandle, KpId); unlikely(!Res)) { + return Res.error(); + } else { + auto [ResSize, ResVersion] = *Res; + + auto SafeResSize = toWasiSize(ResSize); + if (unlikely(!SafeResSize)) { + return SafeResSize.error(); + } + + *Size = *SafeResSize; + + *Version = ResVersion; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairFromId::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + uint32_t KpIdPtr, uint32_t KpIdLen, + uint64_t KpVersion, + uint32_t /* Out */ KpHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiKpIdLen = KpIdLen; + const auto KpId = MemInst->getSpan(KpIdPtr, WasiKpIdLen); + checkRangeExist(KpId, WasiKpIdLen); + + auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); + checkExist(KpHandle); + + if (auto Res = Ctx.keypairFromId(SecretsManagerHandle, KpId, KpVersion); + unlikely(!Res)) { + return Res.error(); + } else { + *KpHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairFromPkAndSk::body(const Runtime::CallingFrame &Frame, + int32_t PkHandle, int32_t SkHandle, + uint32_t /* Out */ KpHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); + checkExist(KpHandle); + + if (auto Res = Ctx.keypairFromPkAndSk(PkHandle, SkHandle); unlikely(!Res)) { + return Res.error(); + } else { + *KpHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairExport::body(const Runtime::CallingFrame &Frame, + int32_t KpHandle, uint32_t KpEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + __wasi_keypair_encoding_e_t WasiKpEncoding; + if (auto Res = cast<__wasi_keypair_encoding_e_t>(KpEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + WasiKpEncoding = *Res; + } + + auto *const ArrayOutputHandle = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutputHandle); + + if (auto Res = Ctx.keypairExport(KpHandle, WasiKpEncoding); unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutputHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairPublickey::body(const Runtime::CallingFrame &Frame, + int32_t KpHandle, + uint32_t /* Out */ PkHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const PkHandle = MemInst->getPointer<__wasi_keypair_t *>(PkHandlePtr); + checkExist(PkHandle); + + if (auto Res = Ctx.keypairPublickey(KpHandle); unlikely(!Res)) { + return Res.error(); + } else { + *PkHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairSecretkey::body(const Runtime::CallingFrame &Frame, + int32_t KpHandle, + uint32_t /* Out */ SkHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const SkHandle = MemInst->getPointer<__wasi_keypair_t *>(SkHandlePtr); + checkExist(SkHandle); + + if (auto Res = Ctx.keypairSecretkey(KpHandle); unlikely(!Res)) { + return Res.error(); + } else { + *SkHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairClose::body(const Runtime::CallingFrame &, + int32_t KpHandle) { + if (auto Res = Ctx.keypairClose(KpHandle); unlikely(!Res)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect PublickeyImport::body(const Runtime::CallingFrame &Frame, + uint32_t AlgType, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ PkHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + + AsymmetricCommon::Algorithm WasiAlg; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( + [Alg](auto WasiAlgType) { return tryFrom(WasiAlgType, Alg); }); + unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + const __wasi_size_t WasiEncodedLen = EncodedLen; + const auto Encoded = + MemInst->getSpan(EncodedPtr, WasiEncodedLen); + checkRangeExist(Encoded, WasiEncodedLen); + + __wasi_publickey_encoding_e_t WasiPkEncoding; + if (auto Res = cast<__wasi_publickey_encoding_e_t>(Encoding); !Res) { + return Res.error(); + } else { + WasiPkEncoding = *Res; + } + + auto *const PkHandle = MemInst->getPointer<__wasi_publickey_t *>(PkHandlePtr); + checkExist(PkHandle); + + if (auto Res = Ctx.publickeyImport(WasiAlg, Encoded, WasiPkEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *PkHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +PublickeyExport::body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + uint32_t PkEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + __wasi_publickey_encoding_e_t WasiPkEncoding; + if (auto Res = cast<__wasi_publickey_encoding_e_t>(PkEncoding); !Res) { + return Res.error(); + } else { + WasiPkEncoding = *Res; + } + + auto *const ArrayOutputHandle = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutputHandle); + + if (auto Res = Ctx.publickeyExport(PkHandle, WasiPkEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutputHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect PublickeyVerify::body(const Runtime::CallingFrame &, + int32_t PkHandle) { + if (auto Res = Ctx.publickeyVerify(PkHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +PublickeyFromSecretkey::body(const Runtime::CallingFrame &Frame, + int32_t SkHandle, uint32_t /* Out */ PkHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const PkHandle = MemInst->getPointer<__wasi_publickey_t *>(PkHandlePtr); + checkExist(PkHandle); + + if (auto Res = Ctx.publickeyFromSecretkey(SkHandle); unlikely(!Res)) { + return Res.error(); + } else { + *PkHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect PublickeyClose::body(const Runtime::CallingFrame &, + int32_t PkHandle) { + if (auto Res = Ctx.publickeyClose(PkHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect SecretkeyImport::body(const Runtime::CallingFrame &Frame, + uint32_t AlgType, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ SkHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + + AsymmetricCommon::Algorithm WasiAlg; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( + [Alg](auto WasiAlgType) { return tryFrom(WasiAlgType, Alg); }); + unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + const __wasi_size_t WasiEncodedLen = EncodedLen; + const auto Encoded = + MemInst->getSpan(EncodedPtr, WasiEncodedLen); + checkRangeExist(Encoded, WasiEncodedLen); + + auto WasiEncoding = cast<__wasi_secretkey_encoding_e_t>(Encoding); + if (!WasiEncoding) { + return WasiEncoding.error(); + } + + auto *const SkHandle = MemInst->getPointer<__wasi_secretkey_t *>(SkHandlePtr); + checkExist(SkHandle); + + if (auto Res = Ctx.secretkeyImport(WasiAlg, Encoded, *WasiEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *SkHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +SecretkeyExport::body(const Runtime::CallingFrame &Frame, int32_t SkHandle, + uint32_t SkEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + __wasi_secretkey_encoding_e_t WasiSkEncoding; + if (auto Res = cast<__wasi_secretkey_encoding_e_t>(SkEncoding); !Res) { + return Res.error(); + } else { + WasiSkEncoding = *Res; + } + + auto *const ArrayOutputHandle = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutputHandle); + + if (auto Res = Ctx.secretkeyExport(SkHandle, WasiSkEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutputHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect SecretkeyClose::body(const Runtime::CallingFrame &, + int32_t Sk) { + if (auto Res = Ctx.secretkeyClose(Sk); unlikely(!Res)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/func.h b/plugins/wasi_crypto/asymmetric_common/func.h new file mode 100644 index 00000000..3c5411a5 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/func.h @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/func.h -------------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the asymmetric common host functions of wasi-crypto. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "utils/hostfunction.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +class KeypairGenerate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgType, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsHandlePtr, + uint32_t /* Out */ KpHandlePtr); +}; + +class KeypairImport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgType, + uint32_t AlgPtr, uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ KpHandlePtr); +}; + +class KeypairGenerateManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, uint32_t AlgType, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsHandlePtr, uint32_t KpHandlePtr); +}; + +class KeypairStoreManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, int32_t KpHandle, + uint32_t KpIdPtr, uint32_t KpIdMaxLen); +}; + +class KeypairReplaceManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, int32_t OldKpHandle, + int32_t NewKpHandle, uint32_t /* Out */ KpVersionPtr); +}; + +class KeypairId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle, + uint32_t KpIdPtr, uint32_t KpIdMaxLen, + uint32_t /* Out */ SizePtr, + uint32_t /* Out */ KpVersionPtr); +}; + +class KeypairFromId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, uint32_t KpIdPtr, + uint32_t KpIdLen, uint64_t KpVersion, + uint32_t /* Out */ KpHandlePtr); +}; + +class KeypairFromPkAndSk : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + int32_t SkHandle, uint32_t /* Out */ KpHandlePtr); +}; + +class KeypairExport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle, + uint32_t KpEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class KeypairPublickey : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle, + uint32_t /* Out */ PkHandlePtr); +}; + +class KeypairSecretkey : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle, + uint32_t /* Out */ SkHandlePtr); +}; + +class KeypairClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle); +}; + +class PublickeyImport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgType, + uint32_t AlgPtr, uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ PkHandlePtr); +}; + +class PublickeyExport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + uint32_t PkEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class PublickeyVerify : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle); +}; + +class PublickeyFromSecretkey : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t SkHandle, + uint32_t /* Out */ PkHandlePtr); +}; + +class PublickeyClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle); +}; + +class SecretkeyImport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgType, + uint32_t AlgPtr, uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ SkHandlePtr); +}; + +class SecretkeyExport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t SkHandle, + uint32_t SkEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class SecretkeyClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t SkHandle); +}; + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/keypair.cpp b/plugins/wasi_crypto/asymmetric_common/keypair.cpp new file mode 100644 index 00000000..73a602db --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/keypair.cpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "asymmetric_common/keypair.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +WasiCryptoExpect +importKp(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + return std::visit( + [=](auto Factory) noexcept -> WasiCryptoExpect { + return decltype(Factory)::KeyPair::import(Encoded, Encoding); + }, + Alg); +} + +namespace { +/// Correspond signatures: +/// WasiCryptoExpect generate(OptionalRef); +/// is used to get the `OptionsType`. +template struct KpGenerateTrait; +template +struct KpGenerateTrait (*)( + OptionalRef) noexcept> { + using Options = OptionsType; +}; +template +using OptionsType = + typename KpGenerateTrait::Options; +} // namespace + +WasiCryptoExpect +generateKp(AsymmetricCommon::Algorithm Alg, + OptionalRef OptOptions) noexcept { + return std::visit( + [=](auto Factory) noexcept -> WasiCryptoExpect { + using RequiredOptionsType = OptionsType; + return transposeOptionalRef( + OptOptions, + [](auto &&Options) noexcept + -> WasiCryptoExpect> { + using InOptionsType = std::decay_t; + if constexpr (std::is_same_v) { + return Options; + } else { + return WasiCryptoUnexpect( + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + }) + .and_then([](auto OptRequiredOptions) noexcept { + return decltype(Factory)::KeyPair::generate(OptRequiredOptions); + }); + }, + Alg); +} + +namespace { +/// Correspond signatures: +/// WasiCryptoExpect Sk::toKeyPair(const PublicKeyType&); +/// is used to get the `PublicKeyType`. +template struct KpFromPkAndSkTrait; +template +struct KpFromPkAndSkTrait (SecretKeyType::*)( + const PublicKeyType &) const noexcept> { + using PublicKey = PublicKeyType; +}; +template +using PkType = typename KpFromPkAndSkTrait::PublicKey; +} // namespace + +WasiCryptoExpect kpFromPkAndSk(const PkVariant &PkVariant, + const SkVariant &SkVariant) noexcept { + return std::visit( + [](const auto &Pk, + const auto &Sk) noexcept -> WasiCryptoExpect { + using RequiredPkType = PkType>; + using InPkType = std::decay_t; + if constexpr (std::is_same_v) { + return Sk.toKeyPair(Pk); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_KEY); + } + }, + PkVariant, SkVariant); +} + +WasiCryptoExpect +kpExportData(const KpVariant &KpVariant, + __wasi_keypair_encoding_e_t Encoding) noexcept { + return std::visit( + [Encoding](const auto &Kp) noexcept { return Kp.exportData(Encoding); }, + KpVariant); +} + +WasiCryptoExpect kpPublicKey(const KpVariant &KpVariant) noexcept { + return std::visit( + [](const auto &Kp) noexcept { + return Kp.publicKey().map([](auto &&Pk) noexcept { + return PkVariant{std::forward(Pk)}; + }); + }, + KpVariant); +} + +WasiCryptoExpect kpSecretKey(const KpVariant &KpVariant) noexcept { + return std::visit( + [](const auto &Kp) noexcept { + return Kp.secretKey().map([](auto &&Sk) noexcept { + return SkVariant{std::forward(Sk)}; + }); + }, + KpVariant); +} + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/keypair.h b/plugins/wasi_crypto/asymmetric_common/keypair.h new file mode 100644 index 00000000..0016b612 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/keypair.h @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/keypair.h ----------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the asymmetric common Keypair of wasi-crypto. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "asymmetric_common/publickey.h" +#include "asymmetric_common/registered.h" +#include "asymmetric_common/secretkey.h" +#include "common/options.h" +#include "utils/error.h" +#include "utils/optional.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +using KpVariant = RegistedAlg::KpVariant; + +WasiCryptoExpect +importKp(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect +generateKp(AsymmetricCommon::Algorithm Alg, + OptionalRef OptOptions) noexcept; + +WasiCryptoExpect +kpExportData(const KpVariant &KpVariant, + __wasi_keypair_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect kpPublicKey(const KpVariant &KpVariant) noexcept; + +WasiCryptoExpect kpSecretKey(const KpVariant &KpVariant) noexcept; + +WasiCryptoExpect kpFromPkAndSk(const PkVariant &Pk, + const SkVariant &Sk) noexcept; +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/module.cpp b/plugins/wasi_crypto/asymmetric_common/module.cpp new file mode 100644 index 00000000..ed3b3448 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/module.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "asymmetric_common/module.h" +#include "asymmetric_common/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoAsymmetricCommonModule::WasiCryptoAsymmetricCommonModule( + std::shared_ptr C) + : ModuleInstance("wasi_ephemeral_crypto_asymmetric_common"), Ctx(C) { + using namespace WasiCrypto; + + addHostFunc("keypair_generate", + std::make_unique(*Ctx)); + addHostFunc("keypair_import", + std::make_unique(*Ctx)); + addHostFunc("keypair_generate_managed", + std::make_unique(*Ctx)); + addHostFunc("keypair_store_managed", + std::make_unique(*Ctx)); + addHostFunc("keypair_replace_managed", + std::make_unique(*Ctx)); + addHostFunc("keypair_id", + std::make_unique(*Ctx)); + addHostFunc("keypair_from_id", + std::make_unique(*Ctx)); + addHostFunc("keypair_from_pk_and_sk", + std::make_unique(*Ctx)); + addHostFunc("keypair_export", + std::make_unique(*Ctx)); + addHostFunc("keypair_publickey", + std::make_unique(*Ctx)); + addHostFunc("keypair_secretkey", + std::make_unique(*Ctx)); + addHostFunc("keypair_close", + std::make_unique(*Ctx)); + addHostFunc("publickey_import", + std::make_unique(*Ctx)); + addHostFunc("publickey_export", + std::make_unique(*Ctx)); + addHostFunc("publickey_verify", + std::make_unique(*Ctx)); + addHostFunc("publickey_from_secretkey", + std::make_unique(*Ctx)); + addHostFunc("publickey_close", + std::make_unique(*Ctx)); + addHostFunc("secretkey_import", + std::make_unique(*Ctx)); + addHostFunc("secretkey_export", + std::make_unique(*Ctx)); + addHostFunc("secretkey_close", + std::make_unique(*Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/module.h b/plugins/wasi_crypto/asymmetric_common/module.h new file mode 100644 index 00000000..9aa39cf4 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/module.h @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/module.h - Asym ----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto asymmetric_common +/// module class. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiCryptoAsymmetricCommonModule + : public Runtime::Instance::ModuleInstance { +public: + WasiCryptoAsymmetricCommonModule(std::shared_ptr); + + WasiCrypto::Context &getContext() { return *Ctx.get(); } + +private: + std::shared_ptr Ctx; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/publickey.cpp b/plugins/wasi_crypto/asymmetric_common/publickey.cpp new file mode 100644 index 00000000..2ee7d84a --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/publickey.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "asymmetric_common/publickey.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +WasiCryptoExpect +importPk(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + return std::visit( + [=](auto Factory) noexcept -> WasiCryptoExpect { + return decltype(Factory)::PublicKey::import(Encoded, Encoding); + }, + Alg); +} + +WasiCryptoExpect> +pkExportData(const PkVariant &PkVariant, + __wasi_publickey_encoding_e_t Encoding) noexcept { + return std::visit( + [Encoding](const auto &Pk) noexcept { return Pk.exportData(Encoding); }, + PkVariant); +} + +WasiCryptoExpect pkVerify(const PkVariant &PkVariant) noexcept { + return std::visit([](const auto &Pk) noexcept { return Pk.verify(); }, + PkVariant); +} + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/publickey.h b/plugins/wasi_crypto/asymmetric_common/publickey.h new file mode 100644 index 00000000..6148dcbb --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/publickey.h @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/publickey.h --------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the asymmetric common PublicKey of wasi-crypto. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "asymmetric_common/registered.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +using PkVariant = RegistedAlg::PkVariant; + +WasiCryptoExpect +importPk(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect> +pkExportData(const PkVariant &PkVariant, + __wasi_publickey_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect pkVerify(const PkVariant &PkVariant) noexcept; + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/registered.h b/plugins/wasi_crypto/asymmetric_common/registered.h new file mode 100644 index 00000000..f74cb0a0 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/registered.h @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric/registered.h - Registered -===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the register asymmetric common algorithm definitions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/registered.h" +#include "signatures/registered.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +template struct Registered { + using PkVariant = std::variant; + using SkVariant = std::variant; + using KpVariant = std::variant; + using Variant = std::variant; +}; + +template +struct Registered, Kx::Registered> { + using Alg = Registered; +}; + +/// Combine the signatures and kx algorithms. +using RegistedAlg = Registered::Alg; + +using Algorithm = RegistedAlg::Variant; + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/secretkey.cpp b/plugins/wasi_crypto/asymmetric_common/secretkey.cpp new file mode 100644 index 00000000..55f518e7 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/secretkey.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "asymmetric_common/secretkey.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +WasiCryptoExpect +importSk(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + return std::visit( + [=](auto Factory) noexcept -> WasiCryptoExpect { + return decltype(Factory)::SecretKey::import(Encoded, Encoding); + }, + Alg); +} + +WasiCryptoExpect +skExportData(const SkVariant &SkVariant, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + return std::visit( + [Encoding](const auto &Sk) noexcept { return Sk.exportData(Encoding); }, + SkVariant); +} + +WasiCryptoExpect skPublicKey(const SkVariant &SkVariant) noexcept { + return std::visit( + [](const auto &Sk) noexcept { + return Sk.publicKey().map([](auto &&Pk) noexcept { + return PkVariant{std::forward(Pk)}; + }); + }, + SkVariant); +} + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/secretkey.h b/plugins/wasi_crypto/asymmetric_common/secretkey.h new file mode 100644 index 00000000..b9f3ec59 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/secretkey.h @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/secretkey.h --------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the asymmetric common SecretKey of wasi-crypto. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "asymmetric_common/publickey.h" +#include "asymmetric_common/registered.h" +#include "wasi_crypto/api.hpp" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +using SkVariant = RegistedAlg::SkVariant; + +WasiCryptoExpect +importSk(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect +skExportData(const SkVariant &SkVariant, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect skPublicKey(const SkVariant &SkVariant) noexcept; + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/array_output.cpp b/plugins/wasi_crypto/common/array_output.cpp new file mode 100644 index 00000000..2085201c --- /dev/null +++ b/plugins/wasi_crypto/common/array_output.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/array_output.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +std::tuple ArrayOutput::pull(Span Buf) noexcept { + std::scoped_lock Lock{Mutex}; + + using DataPosT = decltype(Data)::difference_type; + + size_t OutputSize = std::min(Buf.size(), Data.size() - Pos); + + std::copy(Data.begin() + static_cast(Pos), + Data.begin() + static_cast(Pos + OutputSize), + Buf.begin()); + Pos += OutputSize; + + return {OutputSize, Pos + OutputSize == Data.size()}; +} + +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/array_output.h b/plugins/wasi_crypto/common/array_output.h new file mode 100644 index 00000000..4ca43e27 --- /dev/null +++ b/plugins/wasi_crypto/common/array_output.h @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/common/array_output.h - ArrayOutput --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the ArrayOutput class definition. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +/// Functions returning arrays whose size is not constant or too large to be +/// safely allocated on the stack return a handle to an ArrayOutput type. +/// +/// More details: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#array-outputs +class ArrayOutput { +public: + ArrayOutput(const ArrayOutput &) noexcept = delete; + ArrayOutput &operator=(const ArrayOutput &) noexcept = delete; + ArrayOutput &operator=(ArrayOutput &&) noexcept = delete; + ArrayOutput(ArrayOutput &&) noexcept = delete; + + ArrayOutput(std::vector &&Data) noexcept : Data(std::move(Data)) {} + + ArrayOutput(SecretVec &&Data) noexcept : Data(std::move(Data)) {} + + /// Copy the contents to the @param Buf buffer. + /// Multiple calls are possible, and the total number of bytes to be read is + /// guaranteed to always match the data size. + /// + /// @returns a tuple of `{bytes read, all-pulled flag}`. The flag is true if + /// all data has been pulled. + std::tuple pull(Span Buf) noexcept; + + /// Return ArrayOutput data size. + size_t len() const noexcept { return Data.size(); } + +private: + const SecretVec Data; + size_t Pos = 0; + std::mutex Mutex; +}; + +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/ctx.cpp b/plugins/wasi_crypto/common/ctx.cpp new file mode 100644 index 00000000..7c8e76b1 --- /dev/null +++ b/plugins/wasi_crypto/common/ctx.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ctx.h" +#include "common/array_output.h" +#include "common/options.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +WasiCryptoExpect +Context::arrayOutputLen(__wasi_array_output_t ArrayOutputHandle) noexcept { + return ArrayOutputManager.get(ArrayOutputHandle) + .map(&Common::ArrayOutput::len); +} + +WasiCryptoExpect +Context::arrayOutputPull(__wasi_array_output_t ArrayOutputHandle, + Span Buf) noexcept { + return ArrayOutputManager.get(ArrayOutputHandle) + .map([=](Common::ArrayOutput &ArrayOutput) noexcept { + auto [Size, AlreadyConsumed] = ArrayOutput.pull(Buf); + if (AlreadyConsumed) { + ArrayOutputManager.close(ArrayOutputHandle); + } + return Size; + }); +} + +WasiCryptoExpect<__wasi_options_t> +Context::optionsOpen(__wasi_algorithm_type_e_t AlgType) noexcept { + return OptionsManager.registerManager(Common::optionsOpen(AlgType)); +} + +WasiCryptoExpect +Context::optionsClose(__wasi_options_t OptionsHandle) noexcept { + return OptionsManager.close(OptionsHandle); +} + +WasiCryptoExpect Context::optionsSet(__wasi_options_t OptionsHandle, + std::string_view Name, + Span Value) noexcept { + return OptionsManager.get(OptionsHandle) + .and_then([Name, Value](auto &&Options) noexcept { + return Common::optionsSet(Options, Name, Value); + }); +} + +WasiCryptoExpect Context::optionsSetU64(__wasi_options_t OptionsHandle, + std::string_view Name, + uint64_t Value) noexcept { + return OptionsManager.get(OptionsHandle) + .and_then([Name, Value](auto &&Options) noexcept { + return Common::optionsSetU64(Options, Name, Value); + }); +} + +WasiCryptoExpect +Context::optionsSetGuestBuffer(__wasi_options_t OptionsHandle, + std::string_view Name, + Span Buf) noexcept { + return OptionsManager.get(OptionsHandle) + .and_then([Name, Buf](auto &&Options) noexcept { + return Common::optionsSetGuestBuffer(Options, Name, Buf); + }); +} + +WasiCryptoExpect<__wasi_secrets_manager_t> +Context::secretsManagerOpen(__wasi_opt_options_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect +Context::secretsManagerClose(__wasi_secrets_manager_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect +Context::secretsManagerInvalidate(__wasi_secrets_manager_t, Span, + __wasi_version_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/func.cpp b/plugins/wasi_crypto/common/func.cpp new file mode 100644 index 00000000..37bfb29f --- /dev/null +++ b/plugins/wasi_crypto/common/func.cpp @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/func.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +Expect ArrayOutputLen::body(const Runtime::CallingFrame &Frame, + int32_t ArrayOutputHandle, + uint32_t /* Out */ SizePtr) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.arrayOutputLen(ArrayOutputHandle).and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect ArrayOutputPull::body(const Runtime::CallingFrame &Frame, + int32_t ArrayOutputHandle, + uint32_t BufPtr, uint32_t BufLen, + uint32_t /* Out */ SizePtr) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiBufLen = BufLen; + const auto Buf = MemInst->getSpan(BufPtr, WasiBufLen); + checkRangeExist(Buf, WasiBufLen); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = + Ctx.arrayOutputPull(ArrayOutputHandle, Buf).and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect OptionsOpen::body(const Runtime::CallingFrame &Frame, + uint32_t AlgType, + uint32_t /* Out */ OptionsHandlePtr) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + __wasi_algorithm_type_e_t WasiAlgType; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType); unlikely(!Res)) { + return Res.error(); + } else { + WasiAlgType = *Res; + } + + auto *const OptionsHandle = + MemInst->getPointer<__wasi_options_t *>(OptionsHandlePtr); + checkExist(OptionsHandle); + + if (auto Res = Ctx.optionsOpen(WasiAlgType); unlikely(!Res)) { + return Res.error(); + } else { + *OptionsHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect OptionsClose::body(const Runtime::CallingFrame &, + int32_t OptionsHandle) { + if (auto Res = Ctx.optionsClose(OptionsHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect OptionsSet::body(const Runtime::CallingFrame &Frame, + int32_t OptionsHandle, uint32_t NamePtr, + uint32_t NameLen, uint32_t ValuePtr, + uint32_t ValueLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiNameLen = NameLen; + const auto Name = MemInst->getStringView(NamePtr, WasiNameLen); + checkRangeExist(Name, WasiNameLen); + + const __wasi_size_t WasiValueLen = ValueLen; + const auto Value = MemInst->getSpan(ValuePtr, WasiValueLen); + checkRangeExist(Value, WasiValueLen); + + if (auto Res = Ctx.optionsSet(OptionsHandle, Name, Value); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect OptionsSetU64::body(const Runtime::CallingFrame &Frame, + int32_t OptionsHandle, uint32_t NamePtr, + uint32_t NameLen, uint64_t Value) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiNameLen = NameLen; + const auto Name = MemInst->getStringView(NamePtr, WasiNameLen); + checkRangeExist(Name, WasiNameLen); + + if (auto Res = Ctx.optionsSetU64(OptionsHandle, Name, Value); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect OptionsSetGuestBuffer::body(const Runtime::CallingFrame &Frame, + int32_t OptionsHandle, + uint32_t NamePtr, uint32_t NameLen, + uint32_t BufPtr, uint32_t BufLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiNameLen = NameLen; + const auto Name = MemInst->getStringView(NamePtr, WasiNameLen); + checkRangeExist(Name, WasiNameLen); + + const __wasi_size_t WasiBufLen = BufLen; + const auto Buf = MemInst->getSpan(BufPtr, WasiBufLen); + checkRangeExist(Buf, WasiBufLen); + + if (auto Res = Ctx.optionsSetGuestBuffer(OptionsHandle, Name, Buf); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +SecretsManagerOpen::body(const Runtime::CallingFrame &Frame, + uint32_t OptOptionsHandlePtr, + uint32_t /* Out */ SecretsManagerHandlePtr) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const OptOptionsHandle = + MemInst->getPointer(OptOptionsHandlePtr); + checkExist(OptOptionsHandle); + + auto *const SecretsManagerHandle = + MemInst->getPointer<__wasi_secrets_manager_t *>(SecretsManagerHandlePtr); + checkExist(SecretsManagerHandle); + + if (auto Res = Ctx.secretsManagerOpen(*OptOptionsHandle); unlikely(!Res)) { + return Res.error(); + } else { + *SecretsManagerHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect SecretsManagerClose::body(const Runtime::CallingFrame &, + int32_t SecretsManagerHandle) { + if (auto Res = Ctx.secretsManagerClose(SecretsManagerHandle); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +SecretsManagerInvalidate::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, uint32_t KeyIdPtr, + uint32_t KeyIdLen, uint64_t Version) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiKeyIdLen = KeyIdLen; + const auto KeyId = MemInst->getSpan(KeyIdPtr, WasiKeyIdLen); + checkRangeExist(KeyId, WasiKeyIdLen); + + if (auto Res = + Ctx.secretsManagerInvalidate(SecretsManagerHandle, KeyId, Version); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/func.h b/plugins/wasi_crypto/common/func.h new file mode 100644 index 00000000..e546672d --- /dev/null +++ b/plugins/wasi_crypto/common/func.h @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/common/func.h - Common func ----------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the common host functions of wasi-crypto. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/hostfunction.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +class ArrayOutputLen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t ArrayOutputHandle, uint32_t /* Out */ SizePtr); +}; + +class ArrayOutputPull : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t ArrayOutputHandle, uint32_t BufPtr, + uint32_t BufLen, uint32_t /* Out */ SizePtr); +}; + +class OptionsOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgType, + uint32_t /* Out */ OptionsHandlePtr); +}; + +class OptionsClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t OptionsHandle); +}; + +class OptionsSet : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t OptionsHandle, uint32_t NamePtr, + uint32_t NameLen, uint32_t ValuePtr, uint32_t ValueLen); +}; + +class OptionsSetU64 : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t OptionsHandle, uint32_t NamePtr, + uint32_t NameLen, uint64_t Value); +}; + +class OptionsSetGuestBuffer : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t OptionsHandle, uint32_t NamePtr, + uint32_t NameLen, uint32_t BufPtr, uint32_t BufLen); +}; + +class SecretsManagerOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t OptOptionsHandlePtr, + uint32_t /* Out */ SecretsManagerHandlePtr); +}; + +class SecretsManagerClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle); +}; + +class SecretsManagerInvalidate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, uint32_t KeyIdPtr, + uint32_t KeyIdLen, uint64_t Version); +}; + +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/module.cpp b/plugins/wasi_crypto/common/module.cpp new file mode 100644 index 00000000..9cdfd46d --- /dev/null +++ b/plugins/wasi_crypto/common/module.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/module.h" +#include "common/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoCommonModule::WasiCryptoCommonModule( + std::shared_ptr C) + : ModuleInstance("wasi_ephemeral_crypto_common"), Ctx(C) { + using namespace WasiCrypto; + + addHostFunc("array_output_len", + std::make_unique(*Ctx)); + addHostFunc("array_output_pull", + std::make_unique(*Ctx)); + addHostFunc("options_open", std::make_unique(*Ctx)); + addHostFunc("options_close", std::make_unique(*Ctx)); + addHostFunc("options_set", std::make_unique(*Ctx)); + addHostFunc("options_set_u64", std::make_unique(*Ctx)); + addHostFunc("options_set_guest_buffer", + std::make_unique(*Ctx)); + addHostFunc("secrets_manager_open", + std::make_unique(*Ctx)); + addHostFunc("secrets_manager_close", + std::make_unique(*Ctx)); + addHostFunc("secrets_manager_invalidate", + std::make_unique(*Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/module.h b/plugins/wasi_crypto/common/module.h new file mode 100644 index 00000000..75a1bc11 --- /dev/null +++ b/plugins/wasi_crypto/common/module.h @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/common/module.h - Common Module ------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto common module class. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiCryptoCommonModule : public Runtime::Instance::ModuleInstance { +public: + WasiCryptoCommonModule(std::shared_ptr); + + WasiCrypto::Context &getContext() { return *Ctx.get(); } + +private: + std::shared_ptr Ctx; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/options.cpp b/plugins/wasi_crypto/common/options.cpp new file mode 100644 index 00000000..e95aefa2 --- /dev/null +++ b/plugins/wasi_crypto/common/options.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/options.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +Options optionsOpen(__wasi_algorithm_type_e_t Alg) noexcept { + switch (Alg) { + case __WASI_ALGORITHM_TYPE_SIGNATURES: + return Options{std::in_place_type}; + case __WASI_ALGORITHM_TYPE_SYMMETRIC: + return Options{std::in_place_type}; + case __WASI_ALGORITHM_TYPE_KEY_EXCHANGE: + return Options{std::in_place_type}; + default: + assumingUnreachable(); + } +} + +WasiCryptoExpect optionsSet(Options &Options, std::string_view Name, + Span Value) noexcept { + return std::visit( + [Name, Value](auto &Option) noexcept { return Option.set(Name, Value); }, + Options); +} + +WasiCryptoExpect optionsSetU64(Options &Options, std::string_view Name, + uint64_t Value) noexcept { + return std::visit( + [Name, Value](auto &Option) noexcept { + return Option.setU64(Name, Value); + }, + Options); +} + +WasiCryptoExpect optionsSetGuestBuffer(Options &Options, + std::string_view Name, + Span Value) noexcept { + return std::visit( + [Name, Value](auto &Option) noexcept { + return Option.setGuestBuffer(Name, Value); + }, + Options); +} + +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/options.h b/plugins/wasi_crypto/common/options.h new file mode 100644 index 00000000..2225d513 --- /dev/null +++ b/plugins/wasi_crypto/common/options.h @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/common/options.h - Options definition ===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Options definition of wasi-crypto. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/options.h" +#include "signatures/options.h" +#include "symmetric/options.h" +#include "wasi_crypto/api.hpp" + +#include "common/span.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +/// Some functions support options. For example, options can be used to access +/// the features that are only relevant to specific ciphers and hash functions. +/// +/// Options are represented as a (key, value) map with string keys. They are +/// attached to a context, such as a cipher state. Applications can set, and +/// also read the value associated with a key in order to either get the default +/// value or obtain the runtime information. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#options +using Options = + std::variant; + +Options optionsOpen(__wasi_algorithm_type_e_t Alg) noexcept; + +/// Set byte vectors. +WasiCryptoExpect optionsSet(Options &Options, std::string_view Name, + Span Value) noexcept; + +/// Set unsigned integers. +WasiCryptoExpect optionsSetU64(Options &Options, std::string_view Name, + uint64_t Value) noexcept; + +/// Set memory buffers. +WasiCryptoExpect optionsSetGuestBuffer(Options &Options, + std::string_view Name, + Span Value) noexcept; +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/ctx.cpp b/plugins/wasi_crypto/ctx.cpp new file mode 100644 index 00000000..fd5f4233 --- /dev/null +++ b/plugins/wasi_crypto/ctx.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ctx.h" +#include "asymmetric_common/module.h" +#include "common/module.h" +#include "kx/module.h" +#include "signatures/module.h" +#include "symmetric/module.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +template +Runtime::Instance::ModuleInstance * +createModule(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new T(WasiCrypto::Context::getInstance()); +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasi_crypto", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 1, 0}, + .ModuleCount = 5, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasi_crypto_asymmetric_common", + .Description = "", + .Create = createModule, + }, + { + .Name = "wasi_crypto_common", + .Description = "", + .Create = createModule, + }, + { + .Name = "wasi_crypto_kx", + .Description = "", + .Create = createModule, + }, + { + .Name = "wasi_crypto_signatures", + .Description = "", + .Create = createModule, + }, + { + .Name = "wasi_crypto_symmetric", + .Description = "", + .Create = createModule, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace + +std::shared_mutex WasiCrypto::Context::Mutex; +std::weak_ptr WasiCrypto::Context::Instance; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/ctx.h b/plugins/wasi_crypto/ctx.h new file mode 100644 index 00000000..caa5bdf4 --- /dev/null +++ b/plugins/wasi_crypto/ctx.h @@ -0,0 +1,365 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/ctx.h - Context class definition -----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto context. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "asymmetric_common/keypair.h" +#include "asymmetric_common/publickey.h" +#include "asymmetric_common/secretkey.h" +#include "common/array_output.h" +#include "common/options.h" +#include "kx/registered.h" +#include "signatures/registered.h" +#include "signatures/signatures.h" +#include "signatures/signstate.h" +#include "signatures/verificationstate.h" +#include "symmetric/key.h" +#include "symmetric/registered.h" +#include "symmetric/state.h" +#include "symmetric/tag.h" +#include "utils/error.h" +#include "utils/handles_manager.h" + +#include "common/span.h" +#include "plugin/plugin.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +class Context { +public: + // Singleton + static std::shared_ptr getInstance() noexcept { + std::unique_lock Lock(Mutex); + std::shared_ptr CtxPtr = Instance.lock(); + if (!CtxPtr) { + CtxPtr.reset(new Context()); + Instance = CtxPtr; + } + return CtxPtr; + } + + Context(const Context &) = delete; + void operator=(const Context &) = delete; + + // Common + + WasiCryptoExpect + arrayOutputLen(__wasi_array_output_t ArrayOutputHandle) noexcept; + + WasiCryptoExpect + arrayOutputPull(__wasi_array_output_t ArrayOutputHandle, + Span Buf) noexcept; + + WasiCryptoExpect<__wasi_options_t> + optionsOpen(__wasi_algorithm_type_e_t AlgType) noexcept; + + WasiCryptoExpect optionsClose(__wasi_options_t OptionsHandle) noexcept; + + WasiCryptoExpect optionsSet(__wasi_options_t OptionsHandle, + std::string_view Name, + Span Value) noexcept; + + WasiCryptoExpect optionsSetU64(__wasi_options_t OptionsHandle, + std::string_view Name, + uint64_t Value) noexcept; + + WasiCryptoExpect optionsSetGuestBuffer(__wasi_options_t OptionsHandle, + std::string_view Name, + Span Buf) noexcept; + + WasiCryptoExpect<__wasi_secrets_manager_t> + secretsManagerOpen(__wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect + secretsManagerClose(__wasi_secrets_manager_t SecretsManagerHandle) noexcept; + + WasiCryptoExpect + secretsManagerInvalidate(__wasi_secrets_manager_t SecretsManagerHandle, + Span KeyId, + __wasi_version_t Version) noexcept; + + // Symmetric + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyGenerate(Symmetric::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyImport(Symmetric::Algorithm Alg, + Span Raw) noexcept; + + WasiCryptoExpect<__wasi_array_output_t> + symmetricKeyExport(__wasi_symmetric_key_t KeyHandle) noexcept; + + WasiCryptoExpect + symmetricKeyClose(__wasi_symmetric_key_t KeyHandle) noexcept; + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyGenerateManaged(__wasi_secrets_manager_t SecretsManagerHandle, + Symmetric::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect + symmetricKeyStoreManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t KeyHandle, + Span KeyId) noexcept; + + WasiCryptoExpect<__wasi_version_t> + symmetricKeyReplaceManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t OldKeyHandle, + __wasi_symmetric_key_t NewKeyHandle) noexcept; + + WasiCryptoExpect> + symmetricKeyId(__wasi_symmetric_key_t KeyHandle, + Span KeyId) noexcept; + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyFromId(__wasi_secrets_manager_t SecretsManagerHandle, + Span KeyId, __wasi_version_t KeyVersion) noexcept; + + WasiCryptoExpect<__wasi_symmetric_state_t> + symmetricStateOpen(Symmetric::Algorithm Alg, + __wasi_opt_symmetric_key_t OptKeyHandle, + __wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect<__wasi_symmetric_state_t> + symmetricStateClone(__wasi_symmetric_state_t StateHandle) noexcept; + + WasiCryptoExpect + symmetricStateOptionsGet(__wasi_symmetric_state_t StateHandle, + std::string_view Name, Span Value) noexcept; + + WasiCryptoExpect + symmetricStateOptionsGetU64(__wasi_symmetric_state_t StateHandle, + std::string_view Name) noexcept; + + WasiCryptoExpect + symmetricStateClose(__wasi_symmetric_state_t StateHandle) noexcept; + + WasiCryptoExpect + symmetricStateAbsorb(__wasi_symmetric_state_t StateHandle, + Span Data) noexcept; + + WasiCryptoExpect + symmetricStateSqueeze(__wasi_symmetric_state_t StateHandle, + Span Out) noexcept; + + WasiCryptoExpect<__wasi_symmetric_tag_t> + symmetricStateSqueezeTag(__wasi_symmetric_state_t StateHandle) noexcept; + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricStateSqueezeKey(__wasi_symmetric_state_t StateHandle, + Symmetric::Algorithm Alg) noexcept; + + WasiCryptoExpect + symmetricStateMaxTagLen(__wasi_symmetric_state_t StateHandle) noexcept; + + WasiCryptoExpect + symmetricStateEncrypt(__wasi_symmetric_state_t StateHandle, Span Out, + Span Data) noexcept; + + WasiCryptoExpect<__wasi_symmetric_tag_t> + symmetricStateEncryptDetached(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) noexcept; + + WasiCryptoExpect + symmetricStateDecrypt(__wasi_symmetric_state_t StateHandle, Span Out, + Span Data) noexcept; + + WasiCryptoExpect + symmetricStateDecryptDetached(__wasi_symmetric_state_t StateHandle, + Span Out, Span Data, + Span RawTag) noexcept; + + WasiCryptoExpect + symmetricStateRatchet(__wasi_symmetric_state_t StateHandle) noexcept; + + WasiCryptoExpect + symmetricTagLen(__wasi_symmetric_tag_t TagHandle) noexcept; + + WasiCryptoExpect symmetricTagPull(__wasi_symmetric_tag_t TagHandle, + Span Buf) noexcept; + + WasiCryptoExpect + symmetricTagVerify(__wasi_symmetric_tag_t TagHandle, + Span RawTag) noexcept; + + WasiCryptoExpect + symmetricTagClose(__wasi_symmetric_tag_t TagHandle) noexcept; + + // Asymmetric_common + + WasiCryptoExpect<__wasi_keypair_t> + keypairGenerate(AsymmetricCommon::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect<__wasi_keypair_t> + keypairImport(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect<__wasi_keypair_t> + keypairGenerateManaged(__wasi_secrets_manager_t SecretsManagerHandle, + AsymmetricCommon::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect + keypairStoreManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_keypair_t KpHandle, Span KpId) noexcept; + + WasiCryptoExpect<__wasi_version_t> + keypairReplaceManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_keypair_t OldKpHandle, + __wasi_keypair_t NewKpHandle) noexcept; + + WasiCryptoExpect> + keypairId(__wasi_keypair_t KpHandle, Span KpId) noexcept; + + WasiCryptoExpect<__wasi_keypair_t> + keypairFromId(__wasi_secrets_manager_t SecretsManagerHandle, + Span KpId, + __wasi_version_t KpIdVersion) noexcept; + + WasiCryptoExpect<__wasi_keypair_t> + keypairFromPkAndSk(__wasi_publickey_t PkHandle, + __wasi_secretkey_t SkHandle) noexcept; + + WasiCryptoExpect<__wasi_array_output_t> + keypairExport(__wasi_keypair_t KpHandle, + __wasi_keypair_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect<__wasi_publickey_t> + keypairPublickey(__wasi_keypair_t KpHandle) noexcept; + + WasiCryptoExpect<__wasi_secretkey_t> + keypairSecretkey(__wasi_keypair_t KpHandle) noexcept; + + WasiCryptoExpect keypairClose(__wasi_keypair_t KpHandle) noexcept; + + WasiCryptoExpect<__wasi_publickey_t> + publickeyImport(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect<__wasi_array_output_t> + publickeyExport(__wasi_publickey_t PkHandle, + __wasi_publickey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect publickeyVerify(__wasi_publickey_t PkHandle) noexcept; + + WasiCryptoExpect<__wasi_publickey_t> + publickeyFromSecretkey(__wasi_secretkey_t SkHandle) noexcept; + + WasiCryptoExpect publickeyClose(__wasi_publickey_t PkHandle) noexcept; + + WasiCryptoExpect<__wasi_secretkey_t> + secretkeyImport(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect<__wasi_array_output_t> + secretkeyExport(__wasi_secretkey_t SkHandle, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect secretkeyClose(__wasi_secretkey_t SkHandle) noexcept; + + // Key_exchange + + WasiCryptoExpect<__wasi_array_output_t> + kxDh(__wasi_kx_publickey_t PkHandle, __wasi_kx_secretkey_t SkHandle) noexcept; + + WasiCryptoExpect> + kxEncapsulate(__wasi_kx_publickey_t PkHandle) noexcept; + + WasiCryptoExpect<__wasi_array_output_t> + kxDecapsulate(__wasi_kx_secretkey_t SkHandle, + Span EncapsulatedSecret) noexcept; + + // Signature + + WasiCryptoExpect<__wasi_array_output_t> + signatureExport(__wasi_signature_t SigHandle, + __wasi_signature_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect<__wasi_signature_t> + signatureImport(Signatures::Algorithm Alg, Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect signatureClose(__wasi_signature_t SigHandle) noexcept; + + WasiCryptoExpect<__wasi_signature_state_t> + signatureStateOpen(__wasi_signature_keypair_t KpHandle) noexcept; + + WasiCryptoExpect + signatureStateUpdate(__wasi_signature_state_t StateHandle, + Span Input) noexcept; + + WasiCryptoExpect<__wasi_signature_t> + signatureStateSign(__wasi_signature_state_t StateHandle) noexcept; + + WasiCryptoExpect + signatureStateClose(__wasi_signature_state_t StateHandle) noexcept; + + WasiCryptoExpect<__wasi_signature_verification_state_t> + signatureVerificationStateOpen( + __wasi_signature_publickey_t PkHandle) noexcept; + + WasiCryptoExpect signatureVerificationStateUpdate( + __wasi_signature_verification_state_t StateHandle, + Span Input) noexcept; + + WasiCryptoExpect signatureVerificationStateVerify( + __wasi_signature_verification_state_t StateHandle, + __wasi_signature_t SigHandle) noexcept; + + WasiCryptoExpect signatureVerificationStateClose( + __wasi_signature_verification_state_t StateHandle) noexcept; + +private: + Context() noexcept {} + + RefHandlesManager<__wasi_array_output_t, Common::ArrayOutput> + ArrayOutputManager{0x00}; + RcHandlesManager<__wasi_options_t, Common::Options> OptionsManager{0x01}; + RefHandlesManager<__wasi_symmetric_tag_t, Symmetric::Tag> SymmetricTagManager{ + 0xa}; + RcHandlesManager<__wasi_symmetric_key_t, Symmetric::KeyVariant> + SymmetricKeyManager{0x09}; + RcHandlesManager<__wasi_symmetric_state_t, Symmetric::StateVariant> + SymmetricStateManager{0x08}; + RcHandlesManager<__wasi_publickey_t, AsymmetricCommon::PkVariant> + PublicKeyManager{0x03}; + RcHandlesManager<__wasi_secretkey_t, AsymmetricCommon::SkVariant> + SecretKeyManager{0x04}; + RcHandlesManager<__wasi_keypair_t, AsymmetricCommon::KpVariant> + KeyPairManager{0x05}; + RcHandlesManager<__wasi_signature_t, Signatures::SigVariant> SignatureManager{ + 0x06}; + RcHandlesManager<__wasi_signature_state_t, Signatures::SignStateVariant> + SignStateManager{0x07}; + RcHandlesManager<__wasi_signature_state_t, + Signatures::VerificationStateVariant> + VerificationStateManager{0x02}; + + static std::shared_mutex Mutex; + static std::weak_ptr Instance; +}; + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/ctx.cpp b/plugins/wasi_crypto/kx/ctx.cpp new file mode 100644 index 00000000..86f7646a --- /dev/null +++ b/plugins/wasi_crypto/kx/ctx.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ctx.h" +#include "kx/kx.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +WasiCryptoExpect<__wasi_array_output_t> +Context::kxDh(__wasi_kx_publickey_t PkHandle, + __wasi_kx_secretkey_t SkHandle) noexcept { + auto Sk = SecretKeyManager.getAs(SkHandle); + if (!Sk) { + return WasiCryptoUnexpect(Sk); + } + + auto Pk = PublicKeyManager.getAs(PkHandle); + if (!Pk) { + return WasiCryptoUnexpect(Pk); + } + + return Kx::dh(*Pk, *Sk).and_then([this](auto &&Data) { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect> +Context::kxEncapsulate(__wasi_kx_publickey_t PkHandle) noexcept { + auto EncapsulatedSecret = + PublicKeyManager.getAs(PkHandle).and_then( + [](auto &&KxPk) noexcept { return Kx::encapsulate(KxPk); }); + if (!EncapsulatedSecret) { + return WasiCryptoUnexpect(EncapsulatedSecret); + } + + auto SecretHandle = + ArrayOutputManager.registerManager(std::move(EncapsulatedSecret->Secret)); + if (!SecretHandle) { + return WasiCryptoUnexpect(SecretHandle); + } + + auto EncapsulatedSecretHandle = ArrayOutputManager.registerManager( + std::move(EncapsulatedSecret->EncapsulatedSecretData)); + if (!EncapsulatedSecretHandle) { + return WasiCryptoUnexpect(EncapsulatedSecretHandle); + } + + return std::tuple(*SecretHandle, *EncapsulatedSecretHandle); +} + +WasiCryptoExpect<__wasi_array_output_t> +Context::kxDecapsulate(__wasi_kx_secretkey_t SkHandle, + Span EncapsulatedSecret) noexcept { + return SecretKeyManager.getAs(SkHandle) + .and_then([EncapsulatedSecret](auto &&KxSk) noexcept { + return Kx::decapsulate(KxSk, EncapsulatedSecret); + }) + .and_then([this](auto &&Secret) noexcept { + return ArrayOutputManager.registerManager( + std::forward(Secret)); + }); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/dh/ecdsa.cpp b/plugins/wasi_crypto/kx/dh/ecdsa.cpp new file mode 100644 index 00000000..a228cfe8 --- /dev/null +++ b/plugins/wasi_crypto/kx/dh/ecdsa.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "kx/dh/ecdsa.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +namespace { +inline const size_t SharedSecretSize = 32; +} // namespace + +template +WasiCryptoExpect +Ecdsa::SecretKey::dh(const PublicKey &Pk) const noexcept { + EvpPkeyCtxPtr SkCtx{EVP_PKEY_CTX_new(this->Ctx.get(), nullptr)}; + opensslCheck(EVP_PKEY_derive_init(SkCtx.get())); + + // Set peer key. + opensslCheck(EVP_PKEY_derive_set_peer(SkCtx.get(), Pk.raw().get())); + + // Generate shared secret. + SecretVec Res(SharedSecretSize); + size_t Size = SharedSecretSize; + ensureOrReturn(EVP_PKEY_derive(SkCtx.get(), Res.data(), &Size), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(Size == SharedSecretSize, + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; +} + +template class Ecdsa; +template class Ecdsa; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/dh/ecdsa.h b/plugins/wasi_crypto/kx/dh/ecdsa.h new file mode 100644 index 00000000..ef3648db --- /dev/null +++ b/plugins/wasi_crypto/kx/dh/ecdsa.h @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/dh/ecdsa.h - Ecdsa alg implement --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of the ECDSA algorithm. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "asymmetric_common/ecdsa.h" +#include "kx/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +template class Ecdsa { + +public: + class PublicKey; + class SecretKey; + class KeyPair; + using Base = + AsymmetricCommon::Ecdsa; + class PublicKey : public Base::PublicKeyBase { + public: + using Base::PublicKeyBase::PublicKeyBase; + + const auto &raw() const { return this->Ctx; } + }; + + class SecretKey : public Base::SecretKeyBase { + public: + using Base::SecretKeyBase::SecretKeyBase; + + WasiCryptoExpect dh(const PublicKey &Pk) const noexcept; + }; + + class KeyPair : public Base::KeyPairBase { + public: + using Base::KeyPairBase::KeyPairBase; + }; +}; + +using EcdsaP256 = Ecdsa; +using EcdsaP384 = Ecdsa; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/dh/x25519.cpp b/plugins/wasi_crypto/kx/dh/x25519.cpp new file mode 100644 index 00000000..53d87a72 --- /dev/null +++ b/plugins/wasi_crypto/kx/dh/x25519.cpp @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "kx/dh/x25519.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +namespace { +inline const size_t PkSize = 32; +inline const size_t SkSize = 32; +inline const size_t KpSize = 64; +inline const size_t SharedSecretSize = 32; +} // namespace + +WasiCryptoExpect> X25519::PublicKey::exportData( + __wasi_publickey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_RAW: { + std::vector Res(PkSize); + + size_t Size = PkSize; + opensslCheck(EVP_PKEY_get_raw_public_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == PkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect X25519::PublicKey::verify() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect X25519::SecretKey::exportData( + __wasi_secretkey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: { + SecretVec Res(SkSize); + + size_t Size = SkSize; + opensslCheck(EVP_PKEY_get_raw_private_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == SkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +X25519::SecretKey::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect +X25519::SecretKey::dh(const PublicKey &Pk) const noexcept { + EvpPkeyCtxPtr SkCtx{EVP_PKEY_CTX_new(Ctx.get(), nullptr)}; + opensslCheck(EVP_PKEY_derive_init(SkCtx.get())); + + // Set peer key. + opensslCheck(EVP_PKEY_derive_set_peer(SkCtx.get(), Pk.raw().get())); + + // Generate shared secret. + SecretVec Res(SharedSecretSize); + size_t Size = SharedSecretSize; + ensureOrReturn(EVP_PKEY_derive(SkCtx.get(), Res.data(), &Size), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(Size == SharedSecretSize, + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; +} + +WasiCryptoExpect +X25519::SecretKey::toKeyPair(const PublicKey &) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect +X25519::KeyPair::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect +X25519::KeyPair::secretKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect X25519::KeyPair::exportData( + __wasi_keypair_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: { + SecretVec Res(KpSize); + + size_t Size = PkSize; + opensslCheck(EVP_PKEY_get_raw_public_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == PkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + Size = SkSize; + opensslCheck( + EVP_PKEY_get_raw_private_key(Ctx.get(), Res.data() + PkSize, &Size)); + ensureOrReturn(Size == SkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +X25519::PublicKey::import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_RAW: { + EvpPkeyPtr Pk{EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, nullptr, + Encoded.data(), Encoded.size())}; + ensureOrReturn(Pk, __WASI_CRYPTO_ERRNO_INVALID_KEY); + return Pk; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +X25519::SecretKey::import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: { + EvpPkeyPtr Sk{EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, nullptr, + Encoded.data(), Encoded.size())}; + ensureOrReturn(Sk, __WASI_CRYPTO_ERRNO_INVALID_KEY); + return Sk; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +X25519::KeyPair::generate(OptionalRef) noexcept { + EvpPkeyCtxPtr Ctx{EVP_PKEY_CTX_new_id(EVP_PKEY_X25519, nullptr)}; + opensslCheck(EVP_PKEY_keygen_init(Ctx.get())); + + EVP_PKEY *Kp = nullptr; + opensslCheck(EVP_PKEY_keygen(Ctx.get(), &Kp)); + + return EvpPkeyPtr{Kp}; +} + +WasiCryptoExpect +X25519::KeyPair::import(Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: { + ensureOrReturn(Encoded.size() == KpSize, __WASI_CRYPTO_ERRNO_INVALID_KEY); + // PublicKey can auto generate from SecretKey. + EvpPkeyPtr Sk{EVP_PKEY_new_raw_private_key( + EVP_PKEY_X25519, nullptr, Encoded.data() + PkSize, SkSize)}; + ensureOrReturn(Sk, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Sk; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/dh/x25519.h b/plugins/wasi_crypto/kx/dh/x25519.h new file mode 100644 index 00000000..ad3de81a --- /dev/null +++ b/plugins/wasi_crypto/kx/dh/x25519.h @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/dh/x25519.h - X25519 alg implement ===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of the X25519 algorithm. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +class X25519 { +public: + class PublicKey { + public: + PublicKey(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + PublicKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect> + exportData(__wasi_publickey_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect verify() const noexcept; + + const auto &raw() const { return Ctx; } + + private: + SharedEvpPkey Ctx; + }; + + class KeyPair; + + class SecretKey { + public: + SecretKey(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + SecretKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect + exportData(__wasi_secretkey_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect dh(const PublicKey &Pk) const noexcept; + + WasiCryptoExpect toKeyPair(const PublicKey &Pk) const noexcept; + + private: + SharedEvpPkey Ctx; + }; + + class KeyPair { + public: + KeyPair(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + KeyPair(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + static WasiCryptoExpect + import(Span Raw, + __wasi_keypair_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect secretKey() const noexcept; + + WasiCryptoExpect + exportData(__wasi_keypair_encoding_e_t Encoding) const noexcept; + + private: + SharedEvpPkey Ctx; + }; +}; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/func.cpp b/plugins/wasi_crypto/kx/func.cpp new file mode 100644 index 00000000..4649c278 --- /dev/null +++ b/plugins/wasi_crypto/kx/func.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "kx/func.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +Expect Dh::body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + int32_t SkHandle, + uint32_t /* Out */ SharedSecretPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const SharedSecret = + MemInst->getPointer<__wasi_array_output_t *>(SharedSecretPtr); + checkExist(SharedSecret); + + if (auto Res = Ctx.kxDh(PkHandle, SkHandle); unlikely(!Res)) { + return Res.error(); + } else { + *SharedSecret = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect Encapsulate::body(const Runtime::CallingFrame &Frame, + int32_t PkHandle, + uint32_t /* Out */ SecretPtr, + uint32_t /* Out */ EncapsulatedSecretPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const Secret = MemInst->getPointer<__wasi_array_output_t *>(SecretPtr); + checkExist(Secret); + + auto *const EncapsulatedSecret = + MemInst->getPointer<__wasi_array_output_t *>(EncapsulatedSecretPtr); + checkExist(EncapsulatedSecret); + + if (auto Res = Ctx.kxEncapsulate(PkHandle); unlikely(!Res)) { + return Res.error(); + } else { + std::tie(*Secret, *EncapsulatedSecret) = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect Decapsulate::body(const Runtime::CallingFrame &Frame, + int32_t SkHandle, + uint32_t EncapsulatedSecretPtr, + uint32_t EncapsulatedSecretLen, + uint32_t /* Out */ SecretPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiEncapsulatedSecretLen = EncapsulatedSecretLen; + const auto EncapsulatedSecret = MemInst->getSpan( + EncapsulatedSecretPtr, WasiEncapsulatedSecretLen); + + checkRangeExist(EncapsulatedSecret, WasiEncapsulatedSecretLen); + + auto *const Secret = MemInst->getPointer<__wasi_array_output_t *>(SecretPtr); + checkExist(Secret); + + if (auto Res = Ctx.kxDecapsulate(SkHandle, EncapsulatedSecret); + unlikely(!Res)) { + return Res.error(); + } else { + *Secret = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/func.h b/plugins/wasi_crypto/kx/func.h new file mode 100644 index 00000000..3f1040c0 --- /dev/null +++ b/plugins/wasi_crypto/kx/func.h @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/func.h - Key Exchange funcs -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Key Exchange host functions of wasi-crypto. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/hostfunction.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +class Dh : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + int32_t SkHandle, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class Encapsulate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + uint32_t /* Out */ SecretPtr, + uint32_t /* Out */ EncapsulatedSecretPtr); +}; + +class Decapsulate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t SkHandle, + uint32_t EncapsulatedSecretPtr, + uint32_t EncapsulatedSecretLen, + uint32_t /* Out */ SecretPtr); +}; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/kx.cpp b/plugins/wasi_crypto/kx/kx.cpp new file mode 100644 index 00000000..79d98e9a --- /dev/null +++ b/plugins/wasi_crypto/kx/kx.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "kx/kx.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +namespace { +template struct DhTrait; + +template +struct DhTrait (SkType::*)(const PkType &) + const noexcept> { + using Pk = PkType; +}; +template using PkType = typename DhTrait::Pk; +} // namespace + +WasiCryptoExpect dh(const PkVariant &PkVariant, + const SkVariant &SkVariant) noexcept { + return std::visit( + [](const auto &Pk, + const auto &Sk) noexcept -> WasiCryptoExpect { + using InPkType = std::decay_t; + using ExpectPkType = PkType>; + if constexpr (std::is_same_v) { + return Sk.dh(Pk); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_KEY); + } + }, + PkVariant, SkVariant); +} + +WasiCryptoExpect +encapsulate(PkVariant &PkVariant) noexcept { + return std::visit( + [](auto &&) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + }, + PkVariant); +} + +WasiCryptoExpect> +decapsulate(SkVariant &, Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/kx.h b/plugins/wasi_crypto/kx/kx.h new file mode 100644 index 00000000..5004ebe9 --- /dev/null +++ b/plugins/wasi_crypto/kx/kx.h @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/kx.h - Key Exchange related -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the key exchange related functions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/registered.h" +#include "utils/error.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +using PkVariant = RegistedAlg::PkVariant; +using SkVariant = RegistedAlg::SkVariant; + +/// Diffie-Hellman based key agreement. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#diffie-hellman-based-key-agreement +WasiCryptoExpect dh(const PkVariant &PkVariant, + const SkVariant &SkVariant) noexcept; + +/// Key encapsulation mechanisms. +/// +/// More detailed +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#key-encapsulation-mechanisms +struct EncapsulatedSecret { + std::vector EncapsulatedSecretData; + std::vector Secret; +}; + +WasiCryptoExpect encapsulate(PkVariant &PkVariant) noexcept; + +WasiCryptoExpect> +decapsulate(SkVariant &SkVariant, + Span EncapsulatedSecret) noexcept; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/module.cpp b/plugins/wasi_crypto/kx/module.cpp new file mode 100644 index 00000000..12dff877 --- /dev/null +++ b/plugins/wasi_crypto/kx/module.cpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "kx/module.h" +#include "kx/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoKxModule::WasiCryptoKxModule(std::shared_ptr C) + : ModuleInstance("wasi_ephemeral_crypto_kx"), Ctx(C) { + using namespace WasiCrypto; + + addHostFunc("kx_dh", std::make_unique(*Ctx)); + addHostFunc("kx_encapsulate", std::make_unique(*Ctx)); + addHostFunc("kx_decapsulate", std::make_unique(*Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/module.h b/plugins/wasi_crypto/kx/module.h new file mode 100644 index 00000000..aa3ea512 --- /dev/null +++ b/plugins/wasi_crypto/kx/module.h @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/module.h - Kx Module --------------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto Kx module class. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiCryptoKxModule : public Runtime::Instance::ModuleInstance { +public: + WasiCryptoKxModule(std::shared_ptr); + + WasiCrypto::Context &getContext() { return *Ctx.get(); } + +private: + std::shared_ptr Ctx; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/options.cpp b/plugins/wasi_crypto/kx/options.cpp new file mode 100644 index 00000000..c2612e24 --- /dev/null +++ b/plugins/wasi_crypto/kx/options.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "kx/options.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +WasiCryptoExpect Options::set(std::string_view, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::setU64(std::string_view, uint64_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::setGuestBuffer(std::string_view, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::get(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::getU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/options.h b/plugins/wasi_crypto/kx/options.h new file mode 100644 index 00000000..83fb85d3 --- /dev/null +++ b/plugins/wasi_crypto/kx/options.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/options.h - Key exchange Options --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Key Exchange Options class definition. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" + +#include "common/span.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +class Options { +public: + WasiCryptoExpect set(std::string_view Name, + Span Value) noexcept; + + WasiCryptoExpect setU64(std::string_view Name, uint64_t Value) noexcept; + + WasiCryptoExpect setGuestBuffer(std::string_view Name, + Span Buffer) noexcept; + + WasiCryptoExpect get(std::string_view Name, + Span Value) const noexcept; + + WasiCryptoExpect getU64(std::string_view Name) const noexcept; +}; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/registered.h b/plugins/wasi_crypto/kx/registered.h new file mode 100644 index 00000000..6020f3b9 --- /dev/null +++ b/plugins/wasi_crypto/kx/registered.h @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/registered.h - Registered ---------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the register key exchange algorithm definitions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/dh/ecdsa.h" +#include "kx/dh/x25519.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +template struct Registered { + using PkVariant = std::variant; + using SkVariant = std::variant; + using KpVariant = std::variant; + using Variant = std::variant; +}; + +using RegistedAlg = Registered; + +using Algorithm = RegistedAlg::Variant; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/ctx.cpp b/plugins/wasi_crypto/signatures/ctx.cpp new file mode 100644 index 00000000..ea85c965 --- /dev/null +++ b/plugins/wasi_crypto/signatures/ctx.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ctx.h" +#include "signatures/signatures.h" +#include "signatures/signstate.h" +#include "signatures/verificationstate.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +WasiCryptoExpect<__wasi_array_output_t> +Context::signatureExport(__wasi_signature_t SigHandle, + __wasi_signature_encoding_e_t Encoding) noexcept { + return SignatureManager.get(SigHandle) + .and_then([Encoding](auto &&SigVariant) noexcept { + return Signatures::sigExportData( + std::forward(SigVariant), Encoding); + }) + .and_then([this](auto &&Data) noexcept { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect +Context::signatureClose(__wasi_signature_t SigHandle) noexcept { + return SignatureManager.close(SigHandle); +} + +WasiCryptoExpect<__wasi_array_output_t> +Context::signatureImport(Signatures::Algorithm Alg, Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept { + return Signatures::sigImport(Alg, Encoded, Encoding) + .and_then([this](auto &&Sig) noexcept { + return SignatureManager.registerManager( + std::forward(Sig)); + }); +} + +WasiCryptoExpect<__wasi_signature_state_t> +Context::signatureStateOpen(__wasi_signature_keypair_t KpHandle) noexcept { + return KeyPairManager.getAs(KpHandle) + .and_then([](auto &&KpVariant) noexcept { + return Signatures::sigStateOpen( + std::forward(KpVariant)); + }) + .and_then([this](auto &&SignStateVariant) noexcept { + return SignStateManager.registerManager( + std::forward(SignStateVariant)); + }); +} + +WasiCryptoExpect +Context::signatureStateUpdate(__wasi_signature_state_t StateHandle, + Span Input) noexcept { + return SignStateManager.get(StateHandle) + .and_then([Input](auto &&SignStateVariant) noexcept { + return Signatures::sigStateUpdate(SignStateVariant, Input); + }); +} + +WasiCryptoExpect<__wasi_signature_t> +Context::signatureStateSign(__wasi_signature_state_t StateHandle) noexcept { + return SignStateManager.get(StateHandle) + .and_then([](auto &&SignStateVariant) noexcept { + return Signatures::sigStateSign(SignStateVariant); + }) + .and_then([this](auto &&Signature) noexcept { + return SignatureManager.registerManager( + std::forward(Signature)); + }); +} + +WasiCryptoExpect +Context::signatureStateClose(__wasi_signature_state_t StateHandle) noexcept { + return SignStateManager.close(StateHandle); +} + +WasiCryptoExpect<__wasi_signature_verification_state_t> +Context::signatureVerificationStateOpen( + __wasi_signature_publickey_t PkHandle) noexcept { + return PublicKeyManager.getAs(PkHandle) + .and_then([](auto &&PkVariant) noexcept { + return Signatures::verificationStateOpen( + std::forward(PkVariant)); + }) + .and_then([this](auto &&VerificationStateVariant) noexcept { + return VerificationStateManager.registerManager( + std::forward( + VerificationStateVariant)); + }); +} + +WasiCryptoExpect Context::signatureVerificationStateUpdate( + __wasi_signature_verification_state_t VerificationHandle, + Span Input) noexcept { + return VerificationStateManager.get(VerificationHandle) + .and_then([Input](auto &&VerificationStateVariant) noexcept { + return Signatures::verificationStateUpdate(VerificationStateVariant, + Input); + }); +} + +WasiCryptoExpect Context::signatureVerificationStateVerify( + __wasi_signature_verification_state_t VerificationHandle, + __wasi_signature_t SigHandle) noexcept { + auto Verification = VerificationStateManager.get(VerificationHandle); + if (!Verification) { + return WasiCryptoUnexpect(Verification); + } + + auto Sig = SignatureManager.get(SigHandle); + if (!Sig) { + return WasiCryptoUnexpect(Sig); + } + + return Signatures::verificationStateVerify(*Verification, *Sig); +} + +WasiCryptoExpect Context::signatureVerificationStateClose( + __wasi_signature_verification_state_t VerificationHandle) noexcept { + return VerificationStateManager.close(VerificationHandle); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/ecdsa.cpp b/plugins/wasi_crypto/signatures/ecdsa.cpp new file mode 100644 index 00000000..79c53f1d --- /dev/null +++ b/plugins/wasi_crypto/signatures/ecdsa.cpp @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "signatures/ecdsa.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +namespace { +inline const size_t RawSigSize = 64; +} // namespace + +template +WasiCryptoExpect::VerificationState> +Ecdsa::PublicKey::openVerificationState() const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + opensslCheck(EVP_DigestVerifyInit(SignCtx.get(), nullptr, EVP_sha256(), + nullptr, this->Ctx.get())); + return SignCtx; +} + +template +WasiCryptoExpect::SignState> +Ecdsa::KeyPair::openSignState() const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + opensslCheck(EVP_DigestSignInit(SignCtx.get(), nullptr, EVP_sha256(), nullptr, + this->Ctx.get())); + + return SignCtx; +} + +template +WasiCryptoExpect::Signature> +Ecdsa::Signature::import( + Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: { + ensureOrReturn(Encoded.size() == RawSigSize, + __WASI_CRYPTO_ERRNO_INVALID_SIGNATURE); + EcdsaSigPtr Sig{o2iEcdsaSig(Encoded)}; + ensureOrReturn(Sig, __WASI_CRYPTO_ERRNO_INVALID_SIGNATURE); + return i2dEcdsaSig(Sig.get()); + } + case __WASI_SIGNATURE_ENCODING_DER: { + return std::vector(Encoded.begin(), Encoded.end()); + } + default: + assumingUnreachable(); + } +} + +template +WasiCryptoExpect> Ecdsa::Signature::exportData( + __wasi_signature_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: { + EcdsaSigPtr Sig{d2iEcdsaSig(Data)}; + ensureOrReturn(Sig, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return i2oEcdsaSig(Sig.get()); + } + case __WASI_SIGNATURE_ENCODING_DER: { + return Data; + } + default: + assumingUnreachable(); + } +} + +template +WasiCryptoExpect +Ecdsa::SignState::update(Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_DigestSignUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect::Signature> +Ecdsa::SignState::sign() noexcept { + size_t Size; + // For ECDSA, OpenSSL produces DER-formatted signatures, which means the size + // is not fixed. For more context, see: + // https://bitcoin.stackexchange.com/questions/77191/what-is-the-maximum-size-of-a-der-encoded-ecdsa-signature + // So instead of fixing the size, just read it. + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_DigestSignFinal(Ctx->RawCtx.get(), nullptr, &Size)); + + std::vector Res(Size); + opensslCheck(EVP_DigestSignFinal(Ctx->RawCtx.get(), Res.data(), &Size)); + + return Res; +} + +template +WasiCryptoExpect +Ecdsa::VerificationState::update(Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_DigestVerifyUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect +Ecdsa::VerificationState::verify(const Signature &Sig) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + ensureOrReturn(EVP_DigestVerifyFinal(Ctx->RawCtx.get(), Sig.ref().data(), + Sig.ref().size()), + __WASI_CRYPTO_ERRNO_VERIFICATION_FAILED); + return {}; +} + +template class Ecdsa; +template class Ecdsa; +template class Ecdsa; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/ecdsa.h b/plugins/wasi_crypto/signatures/ecdsa.h new file mode 100644 index 00000000..f3d6d35f --- /dev/null +++ b/plugins/wasi_crypto/signatures/ecdsa.h @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/ecdsa.h - Ecdsa alg -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of the ECDSA algorithm. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "asymmetric_common/ecdsa.h" +#include "signatures/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +template class Ecdsa { +public: + class PublicKey; + class KeyPair; + class SecretKey; + using Base = + AsymmetricCommon::Ecdsa; + class Signature { + public: + Signature(std::vector Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect> + exportData(__wasi_signature_encoding_e_t Encoding) const noexcept; + + const std::vector &ref() const { return Data; } + + private: + // Inner represent as der because OpenSSL use der format for evp interface. + std::vector Data; + }; + + class SignState { + public: + SignState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Input) noexcept; + + WasiCryptoExpect sign() noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr RawCtx) noexcept : RawCtx(std::move(RawCtx)) {} + std::mutex Mutex; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class VerificationState { + public: + VerificationState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Input) noexcept; + + WasiCryptoExpect verify(const Signature &Sig) noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr RawCtx) noexcept : RawCtx(std::move(RawCtx)) {} + std::mutex Mutex; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class PublicKey : public Base::PublicKeyBase { + public: + using Base::PublicKeyBase::PublicKeyBase; + + WasiCryptoExpect openVerificationState() const noexcept; + }; + + class SecretKey : public Base::SecretKeyBase { + public: + using Base::SecretKeyBase::SecretKeyBase; + }; + + class KeyPair : public Base::KeyPairBase { + public: + using Base::KeyPairBase::KeyPairBase; + + WasiCryptoExpect openSignState() const noexcept; + }; +}; + +using EcdsaP256 = Ecdsa; +using EcdsaK256 = Ecdsa; +using EcdsaP384 = Ecdsa; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/eddsa.cpp b/plugins/wasi_crypto/signatures/eddsa.cpp new file mode 100644 index 00000000..f3d36297 --- /dev/null +++ b/plugins/wasi_crypto/signatures/eddsa.cpp @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "signatures/eddsa.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +namespace { +inline constexpr size_t PkSize = 32; +inline constexpr size_t SkSize = 32; +inline constexpr size_t KpSize = 64; +inline constexpr size_t SigSize = 64; +} // namespace + +WasiCryptoExpect +Eddsa::PublicKey::import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_RAW: { + EvpPkeyPtr Ctx{EVP_PKEY_new_raw_public_key(EVP_PKEY_ED25519, nullptr, + Encoded.data(), Encoded.size())}; + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + return Ctx; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect Eddsa::PublicKey::verify() const noexcept { + EvpPkeyCtxPtr CheckCtx{EVP_PKEY_CTX_new(Ctx.get(), nullptr)}; + ensureOrReturn(CheckCtx, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + // EVP_PKEY_public_check() returns 1 for a valid key and 0 for an invalid + // one. A negative value means the check is unsupported for this key type + // (e.g. Ed25519 on OpenSSL 1.1.1), so only an explicit 0 is treated as + // invalid. + ensureOrReturn(EVP_PKEY_public_check(CheckCtx.get()) != 0, + __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {}; +} + +WasiCryptoExpect> Eddsa::PublicKey::exportData( + __wasi_publickey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_RAW: { + size_t Size = PkSize; + std::vector Res(PkSize); + opensslCheck(EVP_PKEY_get_raw_public_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == PkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +Eddsa::PublicKey::openVerificationState() const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + + opensslCheck(EVP_DigestVerifyInit(SignCtx.get(), nullptr, nullptr, nullptr, + Ctx.get())); + return SignCtx; +} + +WasiCryptoExpect +Eddsa::SecretKey::import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: { + EvpPkeyPtr Ctx{EVP_PKEY_new_raw_private_key( + EVP_PKEY_ED25519, nullptr, Encoded.data(), Encoded.size())}; + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + return Ctx; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +Eddsa::SecretKey::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect +Eddsa::SecretKey::toKeyPair(const PublicKey &) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect Eddsa::SecretKey::exportData( + __wasi_secretkey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: { + size_t Size = SkSize; + SecretVec Res(SkSize); + opensslCheck(EVP_PKEY_get_raw_private_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == SkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +Eddsa::KeyPair::generate(OptionalRef) noexcept { + // Generate Key. + EvpPkeyCtxPtr KCtx{EVP_PKEY_CTX_new_id(EVP_PKEY_ED25519, nullptr)}; + opensslCheck(KCtx); + opensslCheck(EVP_PKEY_keygen_init(KCtx.get())); + + EVP_PKEY *Key = nullptr; + opensslCheck(EVP_PKEY_keygen(KCtx.get(), &Key)); + + return EvpPkeyPtr{Key}; +} + +// Refer to: https://github.com/openssl/openssl/issues/8960 +WasiCryptoExpect +Eddsa::KeyPair::import(Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: { + ensureOrReturn(Encoded.size() == KpSize, __WASI_CRYPTO_ERRNO_INVALID_KEY); + // PublicKey can auto generate from SecretKey. + EvpPkeyPtr SkCtx{EVP_PKEY_new_raw_private_key(EVP_PKEY_ED25519, nullptr, + Encoded.data(), SkSize)}; + ensureOrReturn(SkCtx, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return SkCtx; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect Eddsa::KeyPair::exportData( + __wasi_keypair_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: { + SecretVec Res(KpSize); + size_t Size = SkSize; + opensslCheck(EVP_PKEY_get_raw_private_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == SkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + Size = PkSize; + opensslCheck( + EVP_PKEY_get_raw_public_key(Ctx.get(), Res.data() + SkSize, &Size)); + ensureOrReturn(Size == PkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect Eddsa::KeyPair::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect Eddsa::KeyPair::secretKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect +Eddsa::KeyPair::openSignState() const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + opensslCheck(SignCtx); + + opensslCheck( + EVP_DigestSignInit(SignCtx.get(), nullptr, nullptr, nullptr, Ctx.get())); + + return SignCtx; +} + +WasiCryptoExpect +Eddsa::Signature::import(Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: + ensureOrReturn(Encoded.size() == SigSize, + __WASI_CRYPTO_ERRNO_INVALID_SIGNATURE); + return std::vector(Encoded.begin(), Encoded.end()); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect> Eddsa::Signature::exportData( + __wasi_signature_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: + return Data; + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +Eddsa::SignState::update(Span Input) noexcept { + // Notice: EdDSA is one-shot in OpenSSL, so we need a cache for updating + // instead of calling `EVP_DigestSignUpdate`. + std::scoped_lock Lock{Ctx->Mutex}; + + Ctx->Data.insert(Ctx->Data.end(), Input.begin(), Input.end()); + return {}; +} + +WasiCryptoExpect Eddsa::SignState::sign() noexcept { + size_t Size = SigSize; + std::vector Res(Size); + + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_DigestSign(Ctx->RawCtx.get(), Res.data(), &Size, + Ctx->Data.data(), Ctx->Data.size())); + ensureOrReturn(Size == SigSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; +} + +WasiCryptoExpect +Eddsa::VerificationState::update(Span Input) noexcept { + // Also oneshot. + std::scoped_lock Lock{Ctx->Mutex}; + + Ctx->Data.insert(Ctx->Data.end(), Input.begin(), Input.end()); + return {}; +} + +WasiCryptoExpect +Eddsa::VerificationState::verify(const Signature &Sig) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + // The invocation to EVP_DigestVerifyFinal() internally finalizes a copy of + // the digest context. This means that EVP_VerifyUpdate() and + // EVP_VerifyFinal() can be called later to digest and verify the additional + // data. + ensureOrReturn(EVP_DigestVerify(Ctx->RawCtx.get(), Sig.ref().data(), + Sig.ref().size(), Ctx->Data.data(), + Ctx->Data.size()), + __WASI_CRYPTO_ERRNO_VERIFICATION_FAILED); + + return {}; +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/eddsa.h b/plugins/wasi_crypto/signatures/eddsa.h new file mode 100644 index 00000000..3ba516dc --- /dev/null +++ b/plugins/wasi_crypto/signatures/eddsa.h @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/eddsa.h - Eddsa alg -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of the EdDSA algorithm. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "signatures/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +class Eddsa { +public: + class Signature { + public: + Signature(std::vector &&Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect> + exportData(__wasi_signature_encoding_e_t Encoding) const noexcept; + + const std::vector &ref() const { return Data; } + + private: + std::vector Data; + }; + + class SignState { + public: + SignState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Input) noexcept; + + WasiCryptoExpect sign() noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr Ctx) noexcept : RawCtx(std::move(Ctx)) {} + std::mutex Mutex; + std::vector Data; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class VerificationState { + public: + VerificationState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Input) noexcept; + + WasiCryptoExpect verify(const Signature &Sig) noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr Ctx) noexcept : RawCtx(std::move(Ctx)) {} + std::mutex Mutex; + std::vector Data; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class KeyPair; + + class PublicKey { + public: + PublicKey(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + PublicKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect verify() const noexcept; + + WasiCryptoExpect> + exportData(__wasi_publickey_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect openVerificationState() const noexcept; + + private: + SharedEvpPkey Ctx; + }; + + class SecretKey { + public: + SecretKey(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + SecretKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect toKeyPair(const PublicKey &Pk) const noexcept; + + WasiCryptoExpect + exportData(__wasi_secretkey_encoding_e_t Encoding) const noexcept; + + private: + SharedEvpPkey Ctx; + }; + + class KeyPair { + public: + KeyPair(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + KeyPair(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + static WasiCryptoExpect + import(Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect + exportData(__wasi_keypair_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect secretKey() const noexcept; + + WasiCryptoExpect openSignState() const noexcept; + + private: + SharedEvpPkey Ctx; + }; +}; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/func.cpp b/plugins/wasi_crypto/signatures/func.cpp new file mode 100644 index 00000000..18e6905c --- /dev/null +++ b/plugins/wasi_crypto/signatures/func.cpp @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "signatures/func.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +Expect Export::body(const Runtime::CallingFrame &Frame, + int32_t SigHandle, uint32_t Encoding, + uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + __wasi_signature_encoding_e_t WasiEncoding; + if (auto Res = cast<__wasi_signature_encoding_e_t>(Encoding); + unlikely(!Res)) { + return Res.error(); + } else { + WasiEncoding = *Res; + } + + auto *const ArrayOutput = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutput); + + if (auto Res = Ctx.signatureExport(SigHandle, WasiEncoding); unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutput = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect Import::body(const Runtime::CallingFrame &Frame, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t EncodedPtr, uint32_t EncodedLen, + uint32_t Encoding, + uint32_t /* Out */ SigHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + + Algorithm WasiAlg; + if (auto Res = tryFrom(Alg); unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + const __wasi_size_t WasiEncodedLen = EncodedLen; + const auto Encoded = + MemInst->getSpan(EncodedPtr, WasiEncodedLen); + checkRangeExist(Encoded, WasiEncodedLen); + + __wasi_signature_encoding_e_t WasiEncoding; + if (auto Res = cast<__wasi_signature_encoding_e_t>(Encoding); + unlikely(!Res)) { + return Res.error(); + } else { + WasiEncoding = *Res; + } + + auto *const SigHandle = + MemInst->getPointer<__wasi_signature_t *>(SigHandlePtr); + checkExist(SigHandle); + + if (auto Res = Ctx.signatureImport(WasiAlg, Encoded, WasiEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *SigHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateOpen::body(const Runtime::CallingFrame &Frame, + int32_t KpHandle, + uint32_t /* Out */ SigStatePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const SigState = + MemInst->getPointer<__wasi_signature_state_t *>(SigStatePtr); + checkExist(SigState); + + if (auto Res = Ctx.signatureStateOpen(KpHandle); unlikely(!Res)) { + return Res.error(); + } else { + *SigState = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateUpdate::body(const Runtime::CallingFrame &Frame, + int32_t SigStateHandle, uint32_t InputPtr, + uint32_t InputSize) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiInputSize = InputSize; + const auto Input = MemInst->getSpan(InputPtr, WasiInputSize); + checkRangeExist(Input, WasiInputSize); + + if (auto Res = Ctx.signatureStateUpdate(SigStateHandle, Input); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateSign::body(const Runtime::CallingFrame &Frame, + int32_t SigStateHandle, + uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const ArrayOutputHandle = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutputHandle); + + if (auto Res = Ctx.signatureStateSign(SigStateHandle); unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutputHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateClose::body(const Runtime::CallingFrame &Frame, + int32_t SigStateHandle) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + if (auto Res = Ctx.signatureStateClose(SigStateHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +VerificationStateOpen::body(const Runtime::CallingFrame &Frame, + int32_t SigPkHandle, + uint32_t /* Out */ VerificationStateHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const VerificationStateHandle = + MemInst->getPointer<__wasi_signature_state_t *>( + VerificationStateHandlePtr); + checkExist(VerificationStateHandle); + + if (auto Res = Ctx.signatureVerificationStateOpen(SigPkHandle); + unlikely(!Res)) { + return Res.error(); + } else { + *VerificationStateHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +VerificationStateUpdate::body(const Runtime::CallingFrame &Frame, + int32_t SigStateHandle, uint32_t InputPtr, + uint32_t InputSize) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiInputSize = InputSize; + const auto Input = MemInst->getSpan(InputPtr, WasiInputSize); + checkRangeExist(Input, WasiInputSize); + + if (auto Res = Ctx.signatureVerificationStateUpdate(SigStateHandle, Input); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect VerificationStateVerify::body(const Runtime::CallingFrame &, + int32_t VerificationStateHandle, + int32_t SigHandle) { + if (auto Res = Ctx.signatureVerificationStateVerify(VerificationStateHandle, + SigHandle); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect VerificationStateClose::body(const Runtime::CallingFrame &, + int32_t VerificationStateHandle) { + if (auto Res = Ctx.signatureVerificationStateClose(VerificationStateHandle); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect Close::body(const Runtime::CallingFrame &, int32_t SigHandle) { + if (auto Res = Ctx.signatureClose(SigHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/func.h b/plugins/wasi_crypto/signatures/func.h new file mode 100644 index 00000000..f1b8b8a7 --- /dev/null +++ b/plugins/wasi_crypto/signatures/func.h @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/func.h - Signatures func --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the signatures host functions of wasi-crypto. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/hostfunction.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +class Export : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t SigHandle, + uint32_t Encoding, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class Import : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ SigHandlePtr); +}; + +class StateOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle, + uint32_t /* Out */ SigStatePtr); +}; + +class StateUpdate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SigStateHandle, uint32_t InputPtr, + uint32_t InputSize); +}; + +class StateSign : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SigStateHandle, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class StateClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SigStateHandle); +}; + +class VerificationStateOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t SigPkHandle, + uint32_t /* Out */ VerificationStateHandlePtr); +}; + +class VerificationStateUpdate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SigStateHandle, uint32_t InputPtr, + uint32_t InputSize); +}; + +class VerificationStateVerify : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t VerificationStateHandle, int32_t SigHandle); +}; + +class VerificationStateClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t VerificationStateHandle); +}; + +class Close : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t SigHandle); +}; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/module.cpp b/plugins/wasi_crypto/signatures/module.cpp new file mode 100644 index 00000000..36d08b2a --- /dev/null +++ b/plugins/wasi_crypto/signatures/module.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "asymmetric_common/func.h" +#include "common/func.h" +#include "kx/func.h" +#include "signatures/func.h" +#include "symmetric/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoSignaturesModule::WasiCryptoSignaturesModule( + std::shared_ptr C) + : ModuleInstance("wasi_ephemeral_crypto_signatures"), Ctx(C) { + using namespace WasiCrypto; + + addHostFunc("signature_export", std::make_unique(*Ctx)); + addHostFunc("signature_import", std::make_unique(*Ctx)); + addHostFunc("signature_state_open", + std::make_unique(*Ctx)); + addHostFunc("signature_state_update", + std::make_unique(*Ctx)); + addHostFunc("signature_state_sign", + std::make_unique(*Ctx)); + addHostFunc("signature_state_close", + std::make_unique(*Ctx)); + addHostFunc("signature_verification_state_open", + std::make_unique(*Ctx)); + addHostFunc("signature_verification_state_update", + std::make_unique(*Ctx)); + addHostFunc("signature_verification_state_verify", + std::make_unique(*Ctx)); + addHostFunc("signature_verification_state_close", + std::make_unique(*Ctx)); + addHostFunc("signature_close", std::make_unique(*Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/module.h b/plugins/wasi_crypto/signatures/module.h new file mode 100644 index 00000000..1296fff4 --- /dev/null +++ b/plugins/wasi_crypto/signatures/module.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/module.h - Module ---------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto signatures module +/// class. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiCryptoSignaturesModule : public Runtime::Instance::ModuleInstance { +public: + WasiCryptoSignaturesModule(std::shared_ptr); + + WasiCrypto::Context &getContext() { return *Ctx.get(); } + +private: + std::shared_ptr Ctx; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/options.cpp b/plugins/wasi_crypto/signatures/options.cpp new file mode 100644 index 00000000..d4de6cf3 --- /dev/null +++ b/plugins/wasi_crypto/signatures/options.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "signatures/options.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +WasiCryptoExpect Options::set(std::string_view, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::setU64(std::string_view, uint64_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::setGuestBuffer(std::string_view, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::get(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::getU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/options.h b/plugins/wasi_crypto/signatures/options.h new file mode 100644 index 00000000..895c5d9f --- /dev/null +++ b/plugins/wasi_crypto/signatures/options.h @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/options.h - Options -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Signatures Options class definition. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "utils/error.h" + +#include "common/span.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +class Options { +public: + WasiCryptoExpect set(std::string_view Name, + Span Value) noexcept; + + WasiCryptoExpect setU64(std::string_view Name, uint64_t Value) noexcept; + + WasiCryptoExpect setGuestBuffer(std::string_view Name, + Span Buffer) noexcept; + + WasiCryptoExpect get(std::string_view Name, + Span Value) const noexcept; + + WasiCryptoExpect getU64(std::string_view Name) const noexcept; +}; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/registered.h b/plugins/wasi_crypto/signatures/registered.h new file mode 100644 index 00000000..06dcd0d0 --- /dev/null +++ b/plugins/wasi_crypto/signatures/registered.h @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/registered.h - Registered +//-----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the register signature algorithm definitions. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "signatures/ecdsa.h" +#include "signatures/eddsa.h" +#include "signatures/rsa.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +template struct Registered { + using PkVariant = std::variant; + using SkVariant = std::variant; + using KpVariant = std::variant; + using Variant = std::variant; + using SigVariant = std::variant; + using SignStateVariant = std::variant; + using VerificationStateVariant = + std::variant; +}; + +using RegistedAlg = + Registered; + +using Algorithm = RegistedAlg::Variant; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/rsa.cpp b/plugins/wasi_crypto/signatures/rsa.cpp new file mode 100644 index 00000000..722c40a4 --- /dev/null +++ b/plugins/wasi_crypto/signatures/rsa.cpp @@ -0,0 +1,392 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "signatures/rsa.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +template +WasiCryptoExpect::PublicKey> +Rsa::PublicKey::import( + Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_PKCS8: + return importPkcs8(Encoded); + case __WASI_PUBLICKEY_ENCODING_PEM: + return importPem(Encoded); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect::PublicKey> +Rsa::PublicKey::importPkcs8( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPUBKEY(Encoded)}); +} + +template +WasiCryptoExpect::PublicKey> +Rsa::PublicKey::importPem( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPUBKEY(Encoded)}); +} + +template +WasiCryptoExpect +Rsa::PublicKey::checkValid(EvpPkeyPtr Ctx) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + const RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); + ensureOrReturn(RsaKey, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(RSA_bits(RsaKey) == KeyBits, __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; +} + +template +WasiCryptoExpect +Rsa::PublicKey::verify() const noexcept { + ensureOrReturn(RSA_check_key(EVP_PKEY_get0_RSA(Ctx.get())), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {}; +} + +template +WasiCryptoExpect> +Rsa::PublicKey::exportData( + __wasi_publickey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_PKCS8: + return exportPkcs8(); + case __WASI_PUBLICKEY_ENCODING_PEM: + return exportPem(); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect> +Rsa::PublicKey::exportPkcs8() const noexcept { + return i2dPUBKEY(Ctx.get()); +} + +template +WasiCryptoExpect> +Rsa::PublicKey::exportPem() const noexcept { + return pemWritePUBKEY(Ctx.get()); +} + +template +WasiCryptoExpect::VerificationState> +Rsa::PublicKey::openVerificationState() + const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + opensslCheck(EVP_DigestVerifyInit( + SignCtx.get(), nullptr, EVP_get_digestbynid(ShaNid), nullptr, Ctx.get())); + opensslCheck(EVP_PKEY_CTX_set_rsa_padding(EVP_MD_CTX_pkey_ctx(SignCtx.get()), + PadMode)); + return SignCtx; +} + +template +WasiCryptoExpect::SecretKey> +Rsa::SecretKey::import( + Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_PKCS8: + return importPkcs8(Encoded); + case __WASI_SECRETKEY_ENCODING_PEM: + return importPem(Encoded); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect::SecretKey> +Rsa::SecretKey::importPem( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPrivateKey(Encoded)}); +} + +template +WasiCryptoExpect::SecretKey> +Rsa::SecretKey::importPkcs8( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}); +} + +template +WasiCryptoExpect +Rsa::SecretKey::checkValid(EvpPkeyPtr Ctx) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + const RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); + ensureOrReturn(RsaKey, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(RSA_bits(RsaKey) == KeyBits, __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; +} + +template +WasiCryptoExpect::KeyPair> +Rsa::SecretKey::toKeyPair( + const PublicKey &) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +template +WasiCryptoExpect +Rsa::SecretKey::exportData( + __wasi_secretkey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_PKCS8: + return exportPkcs8(); + case __WASI_SECRETKEY_ENCODING_PEM: + return exportPem(); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect::PublicKey> +Rsa::SecretKey::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +template +WasiCryptoExpect +Rsa::SecretKey::exportPem() const noexcept { + return pemWritePrivateKey(Ctx.get()); +} + +template +WasiCryptoExpect +Rsa::SecretKey::exportPkcs8() const noexcept { + EVP_PKEY *Key = Ctx.get(); + BioPtr Bio{BIO_new(BIO_s_mem())}; + + opensslCheck(i2d_PKCS8PrivateKey_bio(Bio.get(), Key, nullptr, nullptr, 0, + nullptr, nullptr)); + + BUF_MEM *Mem = nullptr; + opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); + SecretVec Ret(Mem->length); + + if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { + ensureOrReturn(Size == Ret.size(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + return Ret; +} + +template +WasiCryptoExpect::KeyPair> +Rsa::KeyPair::import( + Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_PKCS8: + return importPkcs8(Encoded); + case __WASI_KEYPAIR_ENCODING_PEM: + return importPem(Encoded); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect::KeyPair> +Rsa::KeyPair::importPem( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPrivateKey(Encoded)}); +} + +template +WasiCryptoExpect::KeyPair> +Rsa::KeyPair::importPkcs8( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}); +} + +template +WasiCryptoExpect +Rsa::KeyPair::checkValid(EvpPkeyPtr Ctx) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + const RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); + ensureOrReturn(RsaKey, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(RSA_bits(RsaKey) == KeyBits, __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; +} + +template +WasiCryptoExpect::SignState> +Rsa::KeyPair::openSignState() const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + opensslCheck(EVP_DigestSignInit( + SignCtx.get(), nullptr, EVP_get_digestbynid(ShaNid), nullptr, Ctx.get())); + opensslCheck(EVP_PKEY_CTX_set_rsa_padding(EVP_MD_CTX_pkey_ctx(SignCtx.get()), + PadMode)); + return SignCtx; +} + +template +WasiCryptoExpect::KeyPair> +Rsa::KeyPair::generate( + OptionalRef) noexcept { + const auto Id = + PadMode == RSA_PKCS1_PADDING ? EVP_PKEY_RSA : EVP_PKEY_RSA_PSS; + EvpPkeyCtxPtr Ctx{EVP_PKEY_CTX_new_id(Id, nullptr)}; + EVP_PKEY_keygen_init(Ctx.get()); + EVP_PKEY_CTX_set_rsa_keygen_bits(Ctx.get(), KeyBits); + EVP_PKEY_CTX_set_signature_md(Ctx.get(), getShaCtx()); + + EVP_PKEY *Res = nullptr; + EVP_PKEY_keygen(Ctx.get(), &Res); + return EvpPkeyPtr{Res}; +} + +template +WasiCryptoExpect Rsa::KeyPair::exportData( + __wasi_keypair_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_PKCS8: + return exportPkcs8(); + case __WASI_KEYPAIR_ENCODING_PEM: + return exportPem(); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect +Rsa::KeyPair::exportPem() const noexcept { + return pemWritePrivateKey(Ctx.get()); +} + +template +WasiCryptoExpect +Rsa::KeyPair::exportPkcs8() const noexcept { + return i2dPrivateKey(Ctx.get()); +} + +template +WasiCryptoExpect::PublicKey> +Rsa::KeyPair::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +template +WasiCryptoExpect::SecretKey> +Rsa::KeyPair::secretKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +template +WasiCryptoExpect::Signature> +Rsa::Signature::import( + Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: + ensureOrReturn(Encoded.size() == getSigSize(), + __WASI_CRYPTO_ERRNO_INVALID_SIGNATURE); + return std::vector(Encoded.begin(), Encoded.end()); + case __WASI_SIGNATURE_ENCODING_DER: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + default: + assumingUnreachable(); + } +} + +template +WasiCryptoExpect> +Rsa::Signature::exportData( + __wasi_signature_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: + return Data; + case __WASI_SIGNATURE_ENCODING_DER: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + default: + assumingUnreachable(); + } +} + +template +WasiCryptoExpect Rsa::SignState::update( + Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_DigestSignUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect::Signature> +Rsa::SignState::sign() noexcept { + size_t Size = getSigSize(); + std::vector Res(Size); + + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_DigestSignFinal(Ctx->RawCtx.get(), Res.data(), &Size)); + ensureOrReturn(Size == getSigSize(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; +} + +template +WasiCryptoExpect Rsa::VerificationState::update( + Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_DigestVerifyUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect Rsa::VerificationState::verify( + const Signature &Sig) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + ensureOrReturn(EVP_DigestVerifyFinal(Ctx->RawCtx.get(), Sig.ref().data(), + Sig.ref().size()), + __WASI_CRYPTO_ERRNO_VERIFICATION_FAILED); + + return {}; +} + +template class Rsa; +template class Rsa; +template class Rsa; + +template class Rsa; +template class Rsa; + +template class Rsa; + +template class Rsa; +template class Rsa; +template class Rsa; + +template class Rsa; +template class Rsa; + +template class Rsa; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/rsa.h b/plugins/wasi_crypto/signatures/rsa.h new file mode 100644 index 00000000..81631080 --- /dev/null +++ b/plugins/wasi_crypto/signatures/rsa.h @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/rsa.h - Rsa alg implement -===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of RSA and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "signatures/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +template class Rsa { +public: + class Signature { + public: + Signature(std::vector &&Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect> + exportData(__wasi_signature_encoding_e_t Encoding) const noexcept; + + const std::vector &ref() const { return Data; } + + private: + std::vector Data; + }; + + class SignState { + public: + SignState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Data) noexcept; + + WasiCryptoExpect sign() noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr RawCtx) noexcept : RawCtx(std::move(RawCtx)) {} + std::mutex Mutex; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class VerificationState { + public: + VerificationState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Data) noexcept; + + WasiCryptoExpect verify(const Signature &Sig) noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr RawCtx) noexcept : RawCtx(std::move(RawCtx)) {} + std::mutex Mutex; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class PublicKey { + public: + PublicKey(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + PublicKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect verify() const noexcept; + + WasiCryptoExpect> + exportData(__wasi_publickey_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect openVerificationState() const noexcept; + + private: + static WasiCryptoExpect + importPem(Span Encoded) noexcept; + + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept; + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept; + + WasiCryptoExpect> exportPem() const noexcept; + + WasiCryptoExpect> exportPkcs8() const noexcept; + + SharedEvpPkey Ctx; + }; + + class KeyPair; + + class SecretKey { + public: + SecretKey(EvpPkeyPtr Ctx) : Ctx(std::move(Ctx)) {} + + SecretKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect + exportData(__wasi_secretkey_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect toKeyPair(const PublicKey &Pk) const noexcept; + + private: + static WasiCryptoExpect + importPem(Span Encoded) noexcept; + + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept; + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept; + + WasiCryptoExpect exportPem() const noexcept; + + WasiCryptoExpect exportPkcs8() const noexcept; + + SharedEvpPkey Ctx; + }; + + class KeyPair { + public: + KeyPair(EvpPkeyPtr Ctx) : Ctx(std::move(Ctx)) {} + + KeyPair(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept; + + static WasiCryptoExpect + generate(OptionalRef OptOptions) noexcept; + + WasiCryptoExpect + exportData(__wasi_keypair_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect secretKey() const noexcept; + + WasiCryptoExpect openSignState() const noexcept; + + private: + static WasiCryptoExpect + importPem(Span Encoded) noexcept; + + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept; + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept; + + WasiCryptoExpect exportPem() const noexcept; + + WasiCryptoExpect exportPkcs8() const noexcept; + + SharedEvpPkey Ctx; + }; + +private: + static constexpr size_t getSigSize() { return KeyBits / 8; } + + static const EVP_MD *getShaCtx() { return EVP_get_digestbynid(ShaNid); } +}; + +using RSA_PKCS1_2048_SHA256 = Rsa; +using RSA_PKCS1_2048_SHA384 = Rsa; +using RSA_PKCS1_2048_SHA512 = Rsa; + +using RSA_PKCS1_3072_SHA384 = Rsa; +using RSA_PKCS1_3072_SHA512 = Rsa; + +using RSA_PKCS1_4096_SHA512 = Rsa; + +using RSA_PSS_2048_SHA256 = Rsa; +using RSA_PSS_2048_SHA384 = Rsa; +using RSA_PSS_2048_SHA512 = Rsa; + +using RSA_PSS_3072_SHA384 = Rsa; +using RSA_PSS_3072_SHA512 = Rsa; + +using RSA_PSS_4096_SHA512 = Rsa; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/signatures.cpp b/plugins/wasi_crypto/signatures/signatures.cpp new file mode 100644 index 00000000..ae3b003e --- /dev/null +++ b/plugins/wasi_crypto/signatures/signatures.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "signatures/signatures.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +WasiCryptoExpect +sigImport(Algorithm Alg, Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept { + return std::visit( + [=](auto Factory) noexcept { + return decltype(Factory)::Signature::import(Encoded, Encoding) + .map([](auto &&Sig) noexcept { + return SigVariant{std::forward(Sig)}; + }); + }, + Alg); +} + +WasiCryptoExpect> +sigExportData(const SigVariant &SigVariant, + __wasi_signature_encoding_e_t Encoding) noexcept { + return std::visit( + [Encoding](auto &Sig) noexcept { return Sig.exportData(Encoding); }, + SigVariant); +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/signatures.h b/plugins/wasi_crypto/signatures/signatures.h new file mode 100644 index 00000000..d2bd2bbb --- /dev/null +++ b/plugins/wasi_crypto/signatures/signatures.h @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/signatures.h - Signatures -===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the implementation of signatures. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "signatures/registered.h" +#include "utils/error.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +using SigVariant = RegistedAlg::SigVariant; + +WasiCryptoExpect +sigImport(Algorithm Alg, Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect> +sigExportData(const SigVariant &SigVariant, + __wasi_signature_encoding_e_t Encoding) noexcept; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/signstate.cpp b/plugins/wasi_crypto/signatures/signstate.cpp new file mode 100644 index 00000000..a8c45707 --- /dev/null +++ b/plugins/wasi_crypto/signatures/signstate.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "signatures/signstate.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +WasiCryptoExpect +sigStateOpen(const KpVariant &KpVariant) noexcept { + return std::visit( + [](const auto &Kp) noexcept { + return Kp.openSignState().map([](auto &&SignState) noexcept { + return SignStateVariant{std::forward(SignState)}; + }); + }, + KpVariant); +} + +WasiCryptoExpect sigStateUpdate(SignStateVariant &SignStateVariant, + Span Input) noexcept { + return std::visit( + [Input](auto &SignState) noexcept { return SignState.update(Input); }, + SignStateVariant); +} + +WasiCryptoExpect +sigStateSign(SignStateVariant &SignStateVariant) noexcept { + return std::visit( + [](auto &SignState) noexcept { + return SignState.sign().map([](auto &&Sig) noexcept { + return SigVariant{std::forward(Sig)}; + }); + }, + SignStateVariant); +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/signstate.h b/plugins/wasi_crypto/signatures/signstate.h new file mode 100644 index 00000000..6cc733ae --- /dev/null +++ b/plugins/wasi_crypto/signatures/signstate.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/signstate.h - SignState ---===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the implementation of sign state. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "signatures/registered.h" +#include "signatures/signatures.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +/// Signature computation. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#signature-creation +using SignStateVariant = RegistedAlg::SignStateVariant; +using KpVariant = RegistedAlg::KpVariant; + +WasiCryptoExpect +sigStateOpen(const KpVariant &PkVariant) noexcept; + +WasiCryptoExpect sigStateUpdate(SignStateVariant &SignStateVariant, + Span Input) noexcept; + +WasiCryptoExpect +sigStateSign(SignStateVariant &SignStateVariant) noexcept; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/verificationstate.cpp b/plugins/wasi_crypto/signatures/verificationstate.cpp new file mode 100644 index 00000000..3f4593ef --- /dev/null +++ b/plugins/wasi_crypto/signatures/verificationstate.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "signatures/verificationstate.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +WasiCryptoExpect +verificationStateOpen(const PkVariant &PkVariant) noexcept { + return std::visit( + [](const auto &Pk) noexcept { + return Pk.openVerificationState().map( + [](auto &&VerificationState) noexcept { + return VerificationStateVariant{ + std::forward(VerificationState)}; + }); + }, + PkVariant); +} + +WasiCryptoExpect +verificationStateUpdate(VerificationStateVariant &VerificationStateVariant, + Span Input) noexcept { + return std::visit( + [Input](auto &VerificationState) noexcept { + return VerificationState.update(Input); + }, + VerificationStateVariant); +} + +namespace { +/// Correspond signatures: +/// WasiCryptoExpect VerificationStateType::verify(const SigType&); +/// is used to get `SigType`. +template struct VerifyTrait; +template +struct VerifyTrait (VerificationStateType::*)( + const SigType &) noexcept> { + using Sig = SigType; +}; + +template +using SigType = typename VerifyTrait::Sig; +} // namespace + +WasiCryptoExpect +verificationStateVerify(VerificationStateVariant &VerificationStateVariant, + const SigVariant &SigVariant) noexcept { + return std::visit( + [](auto &VerificationState, + const auto &Sig) noexcept -> WasiCryptoExpect { + using RequiredSigType = + SigType>; + using InSigType = std::decay_t; + + if constexpr (std::is_same_v) { + return VerificationState.verify(Sig); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_SIGNATURE); + } + }, + VerificationStateVariant, SigVariant); +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/verificationstate.h b/plugins/wasi_crypto/signatures/verificationstate.h new file mode 100644 index 00000000..81e612f1 --- /dev/null +++ b/plugins/wasi_crypto/signatures/verificationstate.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/verificationstate.h -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the verification state related functions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "signatures/registered.h" +#include "signatures/signatures.h" +#include "utils/error.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +/// Signatures verify +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#signature-verification +using VerificationStateVariant = RegistedAlg::VerificationStateVariant; +using PkVariant = RegistedAlg::PkVariant; + +WasiCryptoExpect +verificationStateOpen(const PkVariant &PkVariant) noexcept; + +WasiCryptoExpect +verificationStateUpdate(VerificationStateVariant &VerificationStateVariant, + Span Input) noexcept; + +WasiCryptoExpect +verificationStateVerify(VerificationStateVariant &VerificationStateVariant, + const SigVariant &SigVariant) noexcept; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/aeads.cpp b/plugins/wasi_crypto/symmetric/aeads.cpp new file mode 100644 index 00000000..15f7d127 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/aeads.cpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "symmetric/aeads.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" + +#include +#include + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { +using namespace std::literals; + +template +constexpr size_t Cipher::getKeySize() noexcept { + static_assert(CipherNid == NID_aes_128_gcm || CipherNid == NID_aes_256_gcm || + CipherNid == NID_chacha20_poly1305); + if constexpr (CipherNid == NID_aes_128_gcm) { + return 16; + } + if constexpr (CipherNid == NID_aes_256_gcm) { + return 32; + } + if constexpr (CipherNid == NID_chacha20_poly1305) { + return 32; + } +} + +template +WasiCryptoExpect Cipher::State::maxTagLen() const noexcept { + return getTagSize(); +} + +template +constexpr size_t Cipher::getTagSize() noexcept { + return 16; +} + +template +WasiCryptoExpect::Key> +Cipher::Key::generate(OptionalRef) noexcept { + return SecretVec::random(); +} + +template +WasiCryptoExpect::Key> +Cipher::Key::import(Span Raw) noexcept { + return SecretVec{Raw}; +} + +template +WasiCryptoExpect::State> +Cipher::State::open(const Key &Key, + OptionalRef OptOption) noexcept { + ensureOrReturn(OptOption, __WASI_CRYPTO_ERRNO_NONCE_REQUIRED); + + std::array Nonce; + if (auto Res = OptOption->get("nonce"sv, Nonce); !Res) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_NONCE); + } else { + ensureOrReturn(*Res == NonceSize, __WASI_CRYPTO_ERRNO_INVALID_NONCE); + } + + ensureOrReturn(getKeySize() == Key.ref().size(), + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); + + EvpCipherCtxPtr Ctx{EVP_CIPHER_CTX_new()}; + opensslCheck(EVP_CipherInit_ex(Ctx.get(), EVP_get_cipherbynid(CipherNid), + nullptr, Key.ref().data(), Nonce.data(), + Mode::Unchanged)); + + return State{std::move(Ctx), Nonce}; +} + +template +WasiCryptoExpect +Cipher::State::optionsGet(std::string_view Name, + Span Value) const noexcept { + ensureOrReturn(Name == "nonce"sv, __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + ensureOrReturn(NonceSize <= Value.size(), __WASI_CRYPTO_ERRNO_OVERFLOW); + + std::copy(Ctx->Nonce.begin(), Ctx->Nonce.end(), Value.begin()); + return NonceSize; +} + +// https://wiki.openssl.org/index.php/EVP_Authenticated_Encryption_and_Decryption +template +WasiCryptoExpect +Cipher::State::absorb(Span Data) noexcept { + ensureOrReturn(Data.size() <= + static_cast(std::numeric_limits::max()), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + int DataSize = static_cast(Data.size()); + + int ActualAbsorbSize; + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_CipherUpdate(Ctx->RawCtx.get(), nullptr, &ActualAbsorbSize, + Data.data(), DataSize)); + } + ensureOrReturn(ActualAbsorbSize == DataSize, + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return {}; +} + +template +WasiCryptoExpect +Cipher::State::encrypt(Span Out, + Span Data) noexcept { + return encryptImpl(Out.first(Data.size()), Out.last(getTagSize()), Data); +} + +template +WasiCryptoExpect +Cipher::State::encryptDetached(Span Out, + Span Data) noexcept { + SecretVec Tag(getTagSize()); + if (auto Res = encryptImpl(Out, Tag, Data); !Res) { + return WasiCryptoUnexpect(Res); + } + return Tag; +} + +template +WasiCryptoExpect +Cipher::State::encryptImpl(Span Out, Span Tag, + Span Data) noexcept { + ensureOrReturn(Data.size() <= + static_cast(std::numeric_limits::max()), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + int DataSize = static_cast(Data.size()); + + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_CipherInit_ex(Ctx->RawCtx.get(), nullptr, nullptr, nullptr, + nullptr, Mode::Encrypt)); + + int ActualUpdateSize; + opensslCheck(EVP_CipherUpdate(Ctx->RawCtx.get(), Out.data(), + &ActualUpdateSize, Data.data(), DataSize)); + + int ActualFinalSize; + ensureOrReturn( + EVP_CipherFinal_ex(Ctx->RawCtx.get(), nullptr, &ActualFinalSize), + __WASI_CRYPTO_ERRNO_INTERNAL_ERROR); + + ensureOrReturn(static_cast(ActualUpdateSize + ActualFinalSize) == + Out.size(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + opensslCheck(EVP_CIPHER_CTX_ctrl(Ctx->RawCtx.get(), EVP_CTRL_AEAD_GET_TAG, + static_cast(getTagSize()), Tag.data())); + + return Out.size() + getTagSize(); +} + +template +WasiCryptoExpect +Cipher::State::decrypt(Span Out, + Span Data) noexcept { + return decryptImpl(Out, Data.first(Out.size()), Data.last(getTagSize())); +} + +template +WasiCryptoExpect +Cipher::State::decryptDetached(Span Out, + Span Data, + Span RawTag) noexcept { + return decryptImpl(Out, Data, RawTag); +} + +template +WasiCryptoExpect +Cipher::State::decryptImpl(Span Out, + Span Data, + Span RawTag) noexcept { + ensureOrReturn(Data.size() <= + static_cast(std::numeric_limits::max()), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + int DataSize = static_cast(Data.size()); + + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_CipherInit_ex(Ctx->RawCtx.get(), nullptr, nullptr, nullptr, + nullptr, Mode::Decrypt)); + int ActualUpdateSize; + opensslCheck(EVP_CipherUpdate(Ctx->RawCtx.get(), Out.data(), + &ActualUpdateSize, Data.data(), DataSize)); + + opensslCheck(EVP_CIPHER_CTX_ctrl(Ctx->RawCtx.get(), EVP_CTRL_AEAD_SET_TAG, + static_cast(getTagSize()), + const_cast(RawTag.data()))); + + int ActualFinalSize; + if (!EVP_CipherFinal_ex(Ctx->RawCtx.get(), nullptr, &ActualFinalSize)) { + OPENSSL_cleanse(Out.data(), Out.size()); + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_TAG); + } + ensureOrReturn(ActualFinalSize + ActualUpdateSize == DataSize, + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Out.size(); +} + +template +WasiCryptoExpect::State> +Cipher::State::clone() const noexcept { + // XXX: These cipher didn't implement context duplication from OpenSSL 3.0.0 + // https://github.com/openssl/openssl/issues/20978 + if (0x30000000 <= OPENSSL_VERSION_NUMBER && + (CipherNid == NID_aes_128_gcm || CipherNid == NID_aes_256_gcm || + CipherNid == NID_chacha20_poly1305)) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + } + + EvpCipherCtxPtr CloneCtx{EVP_CIPHER_CTX_new()}; + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_CIPHER_CTX_copy(CloneCtx.get(), Ctx->RawCtx.get())); + } + + return State{std::move(CloneCtx), Ctx->Nonce}; +} + +template class Cipher; +template class Cipher; +template class Cipher; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/aeads.h b/plugins/wasi_crypto/symmetric/aeads.h new file mode 100644 index 00000000..d44123c5 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/aeads.h @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/aeads.h - Aeads related ----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of Aeads and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/options.h" +#include "symmetric/tag.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Aeads invalid operations. Every Aeads state should inherit from this class. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#aeads +template class AEADsState { +public: + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect squeeze(Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeKey() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeTag() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +template class Cipher { + static inline constexpr size_t NonceSize = 12; + +public: + class Key { + public: + Key(SecretVec Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect import(Span Data) noexcept; + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + SecretVec exportData() const noexcept { return Data; } + + const SecretVec &ref() const noexcept { return Data; } + + private: + SecretVec Data; + }; + + class State : public AEADsState { + public: + /// There are four inputs for authenticated encryption: + /// @param[in] Key The secret key for encryption. + /// @param[in] OptOption `Must` contain a Nonce (initialization vector). + static WasiCryptoExpect + open(const Key &Key, OptionalRef OptOption) noexcept; + + State(EvpCipherCtxPtr Ctx, std::array Nonce) noexcept + : Ctx(std::make_shared(std::move(Ctx), Nonce)) {} + + WasiCryptoExpect optionsGet(std::string_view Name, + Span Value) const noexcept; + + /// Absorbs additional data. Multiple calls to absorb() MUST be equivalent + /// to a single call with a concatenation of the inputs. + /// + /// @param Data Additional data + /// @return Nothing or WasiCrypto error + WasiCryptoExpect absorb(Span Data) noexcept; + + /// Return the length required to encode the authentication tag and the + /// optional padding bytes. The returned length MUST be constant for a given + /// algorithm. Guest applications are expected to provide an output buffer + /// whose size is the size of the message, plus the max_tag_len() output + /// value. + /// + /// @return the length required to encode the authentication tag + /// and optional padding bytes. + WasiCryptoExpect maxTagLen() const noexcept; + + /// Check Out.size() == Data.size() + maxTagLen(), then call + /// encryptUnchecked(Out, Data), or return an error if they are not equal. + /// + /// @param Out The encrypted data text + /// @param Data The data to be encrypted + /// @return Tag's size or + /// `__WASI_CRYPTO_ERRNO_OVERFLOW`/`__WASI_CRYPTO_ERRNO_INVALID_LENGTH` + WasiCryptoExpect encrypt(Span Out, + Span Data) noexcept; + + /// Check Out.size() == Data.size(), then call + /// encryptDetachedUnchecked(Out, Data), or return an error if they are not + /// equal. + /// + /// @param Out The encrypted data text + /// @param Data The data to be encrypted + /// @return Tag + /// or `__WASI_CRYPTO_ERRNO_OVERFLOW`/`__WASI_CRYPTO_ERRNO_INVALID_LENGTH` + WasiCryptoExpect encryptDetached(Span Out, + Span Data) noexcept; + + /// Check Out.size() = Data.size() + maxTagLen(), then call + /// decryptDetachedUnchecked(Out, Data), or return an error if they are not + /// equal. + /// + /// @param Out The decrypted data text + /// @param Data The data to be decrypted + /// @return Size or + /// `__WASI_CRYPTO_ERRNO_OVERFLOW`/`__WASI_CRYPTO_ERRNO_INVALID_LENGTH` + WasiCryptoExpect decrypt(Span Out, + Span Data) noexcept; + + /// Check Out.size() == Data.size(), then call + /// encryptDetachedUnchecked(Out, Data), or return an error if they are not + /// equal. + /// + /// @param Out The decrypted data text + /// @param Data The data to be decrypted + /// @return Size or + /// `__WASI_CRYPTO_ERRNO_OVERFLOW`/`__WASI_CRYPTO_ERRNO_INVALID_LENGTH` + WasiCryptoExpect + decryptDetached(Span Out, Span Data, + Span RawTag) noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + WasiCryptoExpect encryptImpl(Span Out, Span Tag, + Span Data) noexcept; + + WasiCryptoExpect decryptImpl(Span Out, + Span Data, + Span RawTag) noexcept; + struct Inner { + Inner(EvpCipherCtxPtr RawCtx, + std::array Nonce) noexcept + : RawCtx(std::move(RawCtx)), Nonce(Nonce) {} + EvpCipherCtxPtr RawCtx; + const std::array Nonce; + std::mutex Mutex; + }; + std::shared_ptr Ctx; + }; + +private: + enum Mode { Unchanged = -1, Decrypt = 0, Encrypt = 1 }; + + constexpr static size_t getKeySize() noexcept; + + constexpr static size_t getTagSize() noexcept; +}; + +using Aes128Gcm = Cipher; +using Aes256Gcm = Cipher; +using ChaCha20Poly1305 = Cipher; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/ctx.cpp b/plugins/wasi_crypto/symmetric/ctx.cpp new file mode 100644 index 00000000..c0dab4e3 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/ctx.cpp @@ -0,0 +1,312 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ctx.h" +#include "symmetric/key.h" +#include "symmetric/state.h" +#include "symmetric/tag.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +WasiCryptoExpect +Context::symmetricTagLen(__wasi_symmetric_tag_t TagHandle) noexcept { + return SymmetricTagManager.get(TagHandle).map(&Symmetric::Tag::len); +} + +WasiCryptoExpect +Context::symmetricTagPull(__wasi_symmetric_tag_t TagHandle, + Span Buf) noexcept { + return SymmetricTagManager.get(TagHandle).and_then( + [Buf](const Symmetric::Tag &Tag) noexcept { return Tag.pull(Buf); }); +} + +WasiCryptoExpect +Context::symmetricTagVerify(__wasi_symmetric_tag_t TagHandle, + Span RawTag) noexcept { + return SymmetricTagManager.get(TagHandle).and_then( + [RawTag](const Symmetric::Tag &Tag) noexcept { + return Tag.verify(RawTag); + }); +} + +WasiCryptoExpect +Context::symmetricTagClose(__wasi_symmetric_tag_t TagHandle) noexcept { + return SymmetricTagManager.close(TagHandle); +} + +WasiCryptoExpect<__wasi_array_output_t> +Context::symmetricKeyExport(__wasi_symmetric_key_t KeyHandle) noexcept { + return SymmetricKeyManager.get(KeyHandle) + .map(Symmetric::keyExportData) + .and_then([this](auto &&Data) noexcept { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect +Context::symmetricKeyClose(__wasi_symmetric_key_t SymmetricKey) noexcept { + return SymmetricKeyManager.close(SymmetricKey); +} + +WasiCryptoExpect +Context::symmetricStateOptionsGet(__wasi_symmetric_state_t StateHandle, + std::string_view Name, + Span Value) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([=](auto &&State) noexcept { + return Symmetric::stateOptionsGet(std::forward(State), + Name, Value); + }); +} + +WasiCryptoExpect +Context::symmetricStateOptionsGetU64(__wasi_symmetric_state_t StateHandle, + std::string_view Name) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([Name](auto &&State) noexcept { + return Symmetric::stateOptionsGetU64( + std::forward(State), Name); + }); +} + +WasiCryptoExpect +Context::symmetricStateClose(__wasi_symmetric_state_t StateHandle) noexcept { + return SymmetricStateManager.close(StateHandle); +} + +WasiCryptoExpect +Context::symmetricStateAbsorb(__wasi_symmetric_state_t StateHandle, + Span Data) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([Data](auto &&State) noexcept { + return Symmetric::stateAbsorb(State, Data); + }); +} + +WasiCryptoExpect +Context::symmetricStateSqueeze(__wasi_symmetric_state_t StateHandle, + Span Out) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([Out](auto &&State) noexcept { + return Symmetric::stateSqueeze(State, Out); + }); +} + +WasiCryptoExpect<__wasi_symmetric_tag_t> Context::symmetricStateSqueezeTag( + __wasi_symmetric_state_t StateHandle) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([](auto &&State) noexcept { + return Symmetric::stateSqueezeTag(State); + }) + .and_then([this](auto &&Tag) { + return SymmetricTagManager.registerManager( + std::forward(Tag)); + }); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +Context::symmetricStateSqueezeKey(__wasi_symmetric_state_t StateHandle, + Symmetric::Algorithm Alg) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([Alg](auto &&State) noexcept { + return Symmetric::stateSqueezeKey(State, Alg); + }) + .and_then([this](auto &&Key) { + return SymmetricKeyManager.registerManager( + std::forward(Key)); + }); +} + +WasiCryptoExpect Context::symmetricStateMaxTagLen( + __wasi_symmetric_state_t StateHandle) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then(&Symmetric::stateMaxTagLen); +} + +WasiCryptoExpect +Context::symmetricStateEncrypt(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([=](auto &&State) noexcept { + return Symmetric::stateEncrypt(State, Out, Data); + }); +} + +WasiCryptoExpect<__wasi_symmetric_tag_t> +Context::symmetricStateEncryptDetached(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([=](auto &&State) noexcept { + return Symmetric::stateEncryptDetached(State, Out, Data); + }) + .and_then([this](auto &&Tag) noexcept { + return SymmetricTagManager.registerManager( + std::forward(Tag)); + }); +} + +WasiCryptoExpect +Context::symmetricStateDecrypt(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([=](auto &&State) noexcept { + return Symmetric::stateDecrypt(State, Out, Data); + }); +} + +WasiCryptoExpect Context::symmetricStateDecryptDetached( + __wasi_symmetric_state_t StateHandle, Span Out, + Span Data, Span RawTag) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([=](auto &&State) noexcept { + return Symmetric::stateDecryptDetached(State, Out, Data, RawTag); + }); +} + +WasiCryptoExpect +Context::symmetricStateRatchet(__wasi_symmetric_state_t StateHandle) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then( + [](auto &&State) noexcept { return Symmetric::stateRatchet(State); }); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +Context::symmetricKeyImport(Symmetric::Algorithm Alg, + Span Raw) noexcept { + return Symmetric::importKey(Alg, Raw).and_then([this](auto &&Key) noexcept { + return SymmetricKeyManager.registerManager( + std::forward(Key)); + }); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +Context::symmetricKeyGenerate(Symmetric::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept { + auto OptOptionsResult = mapAndTransposeOptional( + OptOptionsHandle, [this](__wasi_options_t OptionsHandle) noexcept { + return OptionsManager.get(OptionsHandle); + }); + if (!OptOptionsResult) { + return WasiCryptoUnexpect(OptOptionsResult); + } + + // Refer to OptOptionsResult if it's a Symmetric::Options. + auto OptSymmetricOptionsResult = transposeOptionalToRef( + *OptOptionsResult, + [](const auto &Options) noexcept + -> WasiCryptoExpect> { + auto *SymmetricOptions = std::get_if(&Options); + if (!SymmetricOptions) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + return SymmetricOptions; + }); + if (!OptSymmetricOptionsResult) { + return WasiCryptoUnexpect(OptSymmetricOptionsResult); + } + + return Symmetric::generateKey(Alg, *OptSymmetricOptionsResult) + .and_then([this](auto &&Key) noexcept { + return SymmetricKeyManager.registerManager( + std::forward(Key)); + }); +} + +WasiCryptoExpect<__wasi_symmetric_state_t> +Context::symmetricStateOpen(Symmetric::Algorithm Alg, + __wasi_opt_symmetric_key_t OptKeyHandle, + __wasi_opt_options_t OptOptionsHandle) noexcept { + // Copy from KeyManager. + auto OptKeyResult = + mapAndTransposeOptional(OptKeyHandle, + [this](__wasi_symmetric_key_t KeyHandle) noexcept + -> WasiCryptoExpect { + return SymmetricKeyManager.get(KeyHandle); + }); + if (!OptKeyResult) { + return WasiCryptoUnexpect(OptKeyResult); + } + + // Copy from OptionsManager. + auto OptOptionsResult = mapAndTransposeOptional( + OptOptionsHandle, [this](__wasi_options_t OptionsHandle) noexcept { + return OptionsManager.get(OptionsHandle); + }); + if (!OptOptionsResult) { + return WasiCryptoUnexpect(OptOptionsResult); + } + + // Refer to OptOptionsResult if it's a Smmetric::Options. + auto OptSymmetricOptionsResult = transposeOptionalToRef( + *OptOptionsResult, + [](const auto &Options) noexcept + -> WasiCryptoExpect> { + auto *SymmetricOptions = std::get_if(&Options); + if (!SymmetricOptions) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + return SymmetricOptions; + }); + if (!OptSymmetricOptionsResult) { + return WasiCryptoUnexpect(OptSymmetricOptionsResult); + } + + return Symmetric::openState(Alg, asOptionalRef(*OptKeyResult), + *OptSymmetricOptionsResult) + .and_then([this](auto &&State) noexcept { + return SymmetricStateManager.registerManager( + std::forward(State)); + }); +} + +WasiCryptoExpect<__wasi_symmetric_state_t> +Context::symmetricStateClone(__wasi_symmetric_state_t StateHandle) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then(&Symmetric::stateClone) + .and_then([this](auto &&State) noexcept { + return SymmetricStateManager.registerManager( + std::forward(State)); + }); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +Context::symmetricKeyGenerateManaged(__wasi_secrets_manager_t, + Symmetric::Algorithm, + __wasi_opt_options_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect Context::symmetricKeyStoreManaged( + __wasi_secrets_manager_t, __wasi_symmetric_key_t, Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect<__wasi_version_t> +Context::symmetricKeyReplaceManaged(__wasi_secrets_manager_t, + __wasi_symmetric_key_t, + __wasi_symmetric_key_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect> +Context::symmetricKeyId(__wasi_symmetric_key_t, Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +Context::symmetricKeyFromId(__wasi_secrets_manager_t, Span, + __wasi_version_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/func.cpp b/plugins/wasi_crypto/symmetric/func.cpp new file mode 100644 index 00000000..24ceb43a --- /dev/null +++ b/plugins/wasi_crypto/symmetric/func.cpp @@ -0,0 +1,687 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "symmetric/func.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +Expect KeyGenerate::body(const Runtime::CallingFrame &Frame, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsPtr, + uint32_t /* Out */ KeyHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + + Algorithm WasiAlg; + if (auto Res = tryFrom(Alg); !Res) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const OptOptions = + MemInst->getPointer(OptOptionsPtr); + checkExist(OptOptions); + + auto *const KeyHandle = + MemInst->getPointer<__wasi_symmetric_key_t *>(KeyHandlePtr); + checkExist(KeyHandle); + + if (auto Res = Ctx.symmetricKeyGenerate(WasiAlg, *OptOptions); + unlikely(!Res)) { + return Res.error(); + } else { + *KeyHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyImport::body(const Runtime::CallingFrame &Frame, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t RawPtr, uint32_t RawLen, + uint32_t /* Out */ KeyPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + + Algorithm WasiAlg; + if (auto Res = tryFrom(Alg); !Res) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const Key = MemInst->getPointer<__wasi_symmetric_key_t *>(KeyPtr); + checkExist(Key); + + const __wasi_size_t WasiRawLen = RawLen; + const auto Raw = MemInst->getSpan(RawPtr, WasiRawLen); + checkRangeExist(Raw, WasiRawLen); + + if (auto Res = Ctx.symmetricKeyImport(WasiAlg, Raw); unlikely(!Res)) { + return Res.error(); + } else { + *Key = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyExport::body(const Runtime::CallingFrame &Frame, + int32_t KeyHandle, + uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const ArrayOutputHandle = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutputHandle); + + if (auto Res = Ctx.symmetricKeyExport(KeyHandle); unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutputHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyClose::body(const Runtime::CallingFrame &, + int32_t KeyHandle) { + if (auto Res = Ctx.symmetricKeyClose(KeyHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyGenerateManaged::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsPtr, + uint32_t /* Out */ KeyHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + + Algorithm WasiAlg; + if (auto Res = tryFrom(Alg); !Res) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const OptOptions = + MemInst->getPointer(OptOptionsPtr); + checkExist(OptOptions); + + auto *const KeyHandle = + MemInst->getPointer<__wasi_symmetric_key_t *>(KeyHandlePtr); + checkExist(KeyHandle); + + if (auto Res = Ctx.symmetricKeyGenerateManaged(SecretsManagerHandle, WasiAlg, + *OptOptions); + unlikely(!Res)) { + return Res.error(); + } else { + *KeyHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyStoreManaged::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + int32_t KeyHandle, uint32_t KeyIdPtr, + uint32_t KeyIdMaxLen) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiKeyIdMaxLen = KeyIdMaxLen; + const auto KeyId = MemInst->getSpan(KeyIdPtr, WasiKeyIdMaxLen); + checkRangeExist(KeyId, WasiKeyIdMaxLen); + + if (auto Res = + Ctx.symmetricKeyStoreManaged(SecretsManagerHandle, KeyHandle, KeyId); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyReplaceManaged::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + int32_t OldKeyHandle, + int32_t NewKeyHandle, + uint32_t /* Out */ KeyVersionPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const KeyVersion = + MemInst->getPointer<__wasi_version_t *>(KeyVersionPtr); + checkExist(KeyVersion); + + if (auto Res = Ctx.symmetricKeyReplaceManaged(SecretsManagerHandle, + OldKeyHandle, NewKeyHandle); + unlikely(!Res)) { + return Res.error(); + } else { + *KeyVersion = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyId::body(const Runtime::CallingFrame &Frame, + int32_t KeyHandle, uint32_t KeyIdPtr, + uint32_t KeyIdMaxLen, uint32_t /* Out */ SizePtr, + uint32_t /* Out */ KeyVersionPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiKeyIdMaxLen = KeyIdMaxLen; + const auto KeyId = MemInst->getSpan(KeyIdPtr, WasiKeyIdMaxLen); + checkRangeExist(KeyId, WasiKeyIdMaxLen); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + auto *const KeyVersion = + MemInst->getPointer<__wasi_version_t *>(KeyVersionPtr); + if (unlikely(KeyVersion == nullptr)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + if (auto Res = Ctx.symmetricKeyId(KeyHandle, KeyId); unlikely(!Res)) { + return Res.error(); + } else { + auto [SizeRes, VersionRes] = *Res; + auto SafeSizeRes = toWasiSize(SizeRes); + if (unlikely(!SafeSizeRes)) { + return SafeSizeRes.error(); + } + + *KeyVersion = VersionRes; + *Size = *SafeSizeRes; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyFromId::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + uint32_t KeyIdPtr, uint32_t KeyIdLen, + uint64_t KeyVersion, + uint32_t /* Out */ KeyHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiKeyIdLen = KeyIdLen; + const auto KeyId = MemInst->getSpan(KeyIdPtr, WasiKeyIdLen); + checkRangeExist(KeyId, WasiKeyIdLen); + + auto *const KeyHandle = + MemInst->getPointer<__wasi_symmetric_key_t *>(KeyHandlePtr); + checkExist(KeyHandle); + + if (auto Res = + Ctx.symmetricKeyFromId(SecretsManagerHandle, KeyId, KeyVersion); + unlikely(!Res)) { + return Res.error(); + } else { + *KeyHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateOpen::body(const Runtime::CallingFrame &Frame, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptKeyHandlePtr, + uint32_t OptOptionsPtr, + uint32_t /* Out */ StatePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + Algorithm WasiAlg; + if (auto Res = tryFrom(Alg); !Res) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const OptKeyHandle = + MemInst->getPointer(OptKeyHandlePtr); + checkExist(OptKeyHandle); + + auto *const OptOptions = + MemInst->getPointer(OptOptionsPtr); + checkExist(OptOptions); + + auto *const State = MemInst->getPointer<__wasi_symmetric_state_t *>(StatePtr); + if (unlikely(State == nullptr)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + if (auto Res = Ctx.symmetricStateOpen(WasiAlg, *OptKeyHandle, *OptOptions); + unlikely(!Res)) { + return Res.error(); + } else { + *State = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateClone::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, + uint32_t /* Out */ StatePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const State = MemInst->getPointer<__wasi_symmetric_state_t *>(StatePtr); + if (unlikely(State == nullptr)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + if (auto Res = Ctx.symmetricStateClone(StateHandle); unlikely(!Res)) { + return Res.error(); + } else { + *State = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateOptionsGet::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, uint32_t NamePtr, + uint32_t NameLen, uint32_t ValuePtr, + uint32_t ValueLen, + uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiNameLen = NameLen; + const auto Name = MemInst->getStringView(NamePtr, WasiNameLen); + checkRangeExist(Name, WasiNameLen); + + const __wasi_size_t WasiValueLen = ValueLen; + const auto Value = MemInst->getSpan(ValuePtr, WasiValueLen); + checkRangeExist(Value, WasiValueLen); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.symmetricStateOptionsGet(StateHandle, Name, Value) + .and_then(toWasiSize); + !Res) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateOptionsGetU64::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, uint32_t NamePtr, + uint32_t NameLen, + uint32_t /* Out */ U64Ptr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiNameLen = NameLen; + const auto Name = MemInst->getStringView(NamePtr, WasiNameLen); + checkRangeExist(Name, WasiNameLen); + + auto *const U64 = MemInst->getPointer(U64Ptr); + checkExist(U64); + + if (auto Res = Ctx.symmetricStateOptionsGetU64(StateHandle, Name); + unlikely(!Res)) { + return Res.error(); + } else { + *U64 = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateClose::body(const Runtime::CallingFrame &, + int32_t StateHandle) { + if (auto Res = Ctx.symmetricStateClose(StateHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateAbsorb::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, uint32_t DataPtr, + uint32_t DataLen) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiDataLen = DataLen; + const auto Data = MemInst->getSpan(DataPtr, WasiDataLen); + checkRangeExist(Data, WasiDataLen); + + if (auto Res = Ctx.symmetricStateAbsorb(StateHandle, Data); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateSqueeze::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, uint32_t OutPtr, + uint32_t OutLen) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiOutLen = OutLen; + const auto Out = MemInst->getSpan(OutPtr, WasiOutLen); + checkRangeExist(Out, WasiOutLen); + + if (auto Res = Ctx.symmetricStateSqueeze(StateHandle, Out); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateSqueezeTag::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, + uint32_t /* Out */ TagHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const TagHandle = + MemInst->getPointer<__wasi_symmetric_tag_t *>(TagHandlePtr); + checkExist(TagHandle); + + if (auto Res = Ctx.symmetricStateSqueezeTag(StateHandle); unlikely(!Res)) { + return Res.error(); + } else { + *TagHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateSqueezeKey::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, uint32_t AlgPtr, + uint32_t AlgLen, + uint32_t /* Out */ KeyHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + Algorithm WasiAlg; + if (auto Res = tryFrom(Alg); !Res) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const KeyHandle = + MemInst->getPointer<__wasi_symmetric_key_t *>(KeyHandlePtr); + checkExist(KeyHandle); + + if (auto Res = Ctx.symmetricStateSqueezeKey(StateHandle, WasiAlg); + unlikely(!Res)) { + return Res.error(); + } else { + *KeyHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateMaxTagLen::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, + uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.symmetricStateMaxTagLen(StateHandle).and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateEncrypt::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, uint32_t OutPtr, + uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, + uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiOutLen = OutLen; + const auto Out = MemInst->getSpan(OutPtr, WasiOutLen); + checkRangeExist(Out, WasiOutLen); + + const __wasi_size_t WasiDataLen = DataLen; + const auto Data = MemInst->getSpan(DataPtr, WasiDataLen); + checkRangeExist(Data, WasiDataLen); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.symmetricStateEncrypt(StateHandle, Out, Data) + .and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateEncryptDetached::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, + uint32_t DataPtr, uint32_t DataLen, + uint32_t /* Out */ TagHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiOutLen = OutLen; + const auto Out = MemInst->getSpan(OutPtr, WasiOutLen); + checkRangeExist(Out, WasiOutLen); + + const __wasi_size_t WasiDataLen = DataLen; + const auto Data = MemInst->getSpan(DataPtr, WasiDataLen); + checkRangeExist(Data, WasiDataLen); + + auto *const TagHandle = + MemInst->getPointer<__wasi_symmetric_tag_t *>(TagHandlePtr); + checkExist(TagHandle); + + if (auto Res = Ctx.symmetricStateEncryptDetached(StateHandle, Out, Data); + unlikely(!Res)) { + return Res.error(); + } else { + *TagHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateDecrypt::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, uint32_t OutPtr, + uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, + uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiOutLen = OutLen; + const auto Out = MemInst->getSpan(OutPtr, WasiOutLen); + checkRangeExist(Out, WasiOutLen); + + const __wasi_size_t WasiDataLen = DataLen; + const auto Data = MemInst->getSpan(DataPtr, WasiDataLen); + checkRangeExist(Data, WasiDataLen); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + if (unlikely(Size == nullptr)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + if (auto Res = Ctx.symmetricStateDecrypt(StateHandle, Out, Data) + .and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateDecryptDetached::body( + const Runtime::CallingFrame &Frame, int32_t StateHandle, uint32_t OutPtr, + uint32_t OutLen, uint32_t DataPtr, uint32_t DataLen, uint32_t RawTagPtr, + uint32_t RawTagLen, uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiOutLen = OutLen; + const auto Out = MemInst->getSpan(OutPtr, WasiOutLen); + checkRangeExist(Out, WasiOutLen); + + const __wasi_size_t WasiDataLen = DataLen; + const auto Data = MemInst->getSpan(DataPtr, WasiDataLen); + checkRangeExist(Data, WasiDataLen); + + const __wasi_size_t WasiRawTagLen = RawTagLen; + const auto RawTag = MemInst->getSpan(RawTagPtr, WasiRawTagLen); + checkRangeExist(RawTag, WasiRawTagLen); + + auto *Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = + Ctx.symmetricStateDecryptDetached(StateHandle, Out, Data, RawTag) + .and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateRatchet::body(const Runtime::CallingFrame &, + int32_t StateHandle) { + if (auto Res = Ctx.symmetricStateRatchet(StateHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect TagLen::body(const Runtime::CallingFrame &Frame, + int32_t TagHandle, uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + auto *Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + if (unlikely(Size == nullptr)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + if (auto Res = Ctx.symmetricTagLen(TagHandle).and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect TagPull::body(const Runtime::CallingFrame &Frame, + int32_t TagHandle, uint32_t BufPtr, + uint32_t BufLen, uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiBufLen = BufLen; + const auto Buf = MemInst->getSpan(BufPtr, WasiBufLen); + checkRangeExist(Buf, WasiBufLen); + + auto *Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.symmetricTagPull(TagHandle, Buf).and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect TagVerify::body(const Runtime::CallingFrame &Frame, + int32_t TagHandle, uint32_t RawTagPtr, + uint32_t RawTagLen) { + auto *MemInst = Frame.getMemoryByIndex(0); + checkExist(MemInst); + + const __wasi_size_t WasiRawTagLen = RawTagLen; + const auto RawTag = MemInst->getSpan(RawTagPtr, WasiRawTagLen); + checkRangeExist(RawTag, WasiRawTagLen); + + if (auto Res = Ctx.symmetricTagVerify(TagHandle, RawTag); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect TagClose::body(const Runtime::CallingFrame &, + int32_t TagHandle) { + if (auto Res = Ctx.symmetricTagClose(TagHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/func.h b/plugins/wasi_crypto/symmetric/func.h new file mode 100644 index 00000000..eded49a3 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/func.h @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/func.h - Symmetric funcs ---===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the symmetric host functions of wasi-crypto. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "utils/hostfunction.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +class KeyGenerate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t OptOptionsPtr, + uint32_t /* Out */ KeyHandlePtr); +}; + +class KeyImport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t RawPtr, uint32_t RawLen, + uint32_t /* Out */ KeyHandlePtr); +}; + +class KeyExport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t KeyHandle, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class KeyClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t KeyHandle); +}; + +class KeyGenerateManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t OptOptionsPtr, + uint32_t /* Out */ KeyHandlePtr); +}; + +class KeyStoreManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, int32_t KeyHandle, + uint32_t KeyIdPtr, uint32_t KeyIdMaxLen); +}; + +class KeyReplaceManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, int32_t OldKeyHandle, + int32_t NewKeyHandle, uint32_t /* Out */ KeyVersionPtr); +}; + +class KeyId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t KeyHandle, + uint32_t KeyIdPtr, uint32_t KeyIdMaxLen, + uint32_t /* Out */ SizePtr, + uint32_t /* Out */ KeyVersionPtr); +}; + +class KeyFromId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, uint32_t KeyIdPtr, + uint32_t KeyIdLen, uint64_t KeyVersion, + uint32_t /* Out */ KeyHandlePtr); +}; + +class StateOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t OptKeyHandlePtr, + uint32_t OptOptionsPtr, uint32_t /* Out */ StatePtr); +}; + +class StateClone : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t /* Out */ StatePtr); +}; + +class StateOptionsGet : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t NamePtr, uint32_t NameLen, uint32_t ValuePtr, + uint32_t ValueLen, uint32_t /* Out */ SizePtr); +}; + +class StateOptionsGetU64 : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t NamePtr, uint32_t NameLen, + uint32_t /* Out */ U64Ptr); +}; + +class StateClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t StateHandle); +}; + +class StateAbsorb : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t DataPtr, uint32_t DataLen); +}; + +class StateSqueeze : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen); +}; + +class StateSqueezeTag : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t /* Out */ TagHandlePtr); +}; + +class StateSqueezeKey : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t /* Out */ KeyHandlePtr); +}; + +class StateMaxTagLen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t /* Out */ SizePtr); +}; + +class StateEncrypt : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, uint32_t /* Out */ SizePtr); +}; + +class StateEncryptDetached : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, uint32_t /* Out */ TagHandlePtr); +}; + +class StateDecrypt : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, uint32_t /* Out */ SizePtr); +}; + +class StateDecryptDetached : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, uint32_t RawTagPtr, + uint32_t RawTagLen, uint32_t /* Out */ SizePtr); +}; + +class StateRatchet : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t StateHandle); +}; + +class TagLen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t TagHandle, + uint32_t /* Out */ SizePtr); +}; + +class TagPull : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t TagHandle, + uint32_t BufPtr, uint32_t BufLen, + uint32_t /* Out */ SizePtr); +}; + +class TagVerify : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t TagHandle, + uint32_t RawTagPtr, uint32_t RawTagLen); +}; + +class TagClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t TagHandle); +}; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/hash.cpp b/plugins/wasi_crypto/symmetric/hash.cpp new file mode 100644 index 00000000..2e80f856 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/hash.cpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "symmetric/hash.h" +#include "utils/evp_wrapper.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +template constexpr size_t Sha2::getDigestSize() noexcept { + static_assert(ShaNid == NID_sha256 || ShaNid == NID_sha512 || + ShaNid == NID_sha512_256); + if constexpr (ShaNid == NID_sha256) + return 32; + if constexpr (ShaNid == NID_sha512) + return 64; + if constexpr (ShaNid == NID_sha512_256) + return 32; +} + +template +WasiCryptoExpect::State> +Sha2::State::open(OptionalRef) noexcept { + EvpMdCtxPtr Ctx{EVP_MD_CTX_new()}; + opensslCheck(EVP_DigestInit(Ctx.get(), EVP_get_digestbynid(ShaNid))); + return Ctx; +} + +template +WasiCryptoExpect +Sha2::State::absorb(Span Data) noexcept { + std::unique_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_DigestUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect +Sha2::State::squeeze(Span Out) noexcept { + ensureOrReturn(getDigestSize() >= Out.size(), + __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + + EvpMdCtxPtr CopyCtx{EVP_MD_CTX_new()}; + + { + std::shared_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_MD_CTX_copy_ex(CopyCtx.get(), Ctx->RawCtx.get())); + } + + if (getDigestSize() == Out.size()) { + unsigned int Size; + opensslCheck(EVP_DigestFinal_ex(CopyCtx.get(), Out.data(), &Size)); + ensureOrReturn(Size == getDigestSize(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + if (getDigestSize() > Out.size()) { + unsigned int Size; + std::array Cache; + opensslCheck(EVP_DigestFinal_ex(CopyCtx.get(), Cache.data(), &Size)); + ensureOrReturn(Size == getDigestSize(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + std::copy(Cache.begin(), Cache.begin() + static_cast(Out.size()), + Out.data()); + } + + return {}; +} + +template +WasiCryptoExpect::State> +Sha2::State::clone() const noexcept { + EvpMdCtxPtr CloneCtx{EVP_MD_CTX_new()}; + + { + std::shared_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_MD_CTX_copy_ex(CloneCtx.get(), Ctx->RawCtx.get())); + } + + return CloneCtx; +} + +template class Sha2; +template class Sha2; +template class Sha2; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/hash.h b/plugins/wasi_crypto/symmetric/hash.h new file mode 100644 index 00000000..7baf01b0 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/hash.h @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/hash.h - Hash related ------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of Hash and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/options.h" +#include "symmetric/tag.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Hash never has a key; it is only a placeholder. Every hash key should +/// inherit from this class. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#hash-functions +template class HashKey { +public: + static WasiCryptoExpect import(Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); + } + + static WasiCryptoExpect generate(OptionalRef) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); + } + + SecretVec exportData() const noexcept { assumingUnreachable(); } +}; + +/// Hash invalid operations. Every hash state should inherit from this class. +template class HashState { +public: + /// The current hash does not support any options. + WasiCryptoExpect optionsGet(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + /// The current hash does not support any options. + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encryptDetached(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decryptDetached(Span, Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeKey() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeTag() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect maxTagLen() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +template class Sha2 { +public: + /// In fact, sha2 keys are never produced. This design removes the + /// forwarding declaration. + class Key : public HashKey {}; + + class State : public HashState { + public: + State(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + static WasiCryptoExpect + open(OptionalRef OptOption) noexcept; + + WasiCryptoExpect absorb(Span Data) noexcept; + + WasiCryptoExpect squeeze(Span Out) noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr Ctx) noexcept : RawCtx(std::move(Ctx)) {} + EvpMdCtxPtr RawCtx; + std::shared_mutex Mutex; + }; + std::shared_ptr Ctx; + }; + +private: + /// Return the sha digest size. + constexpr static size_t getDigestSize() noexcept; +}; + +using Sha256 = Sha2; +using Sha512 = Sha2; +using Sha512_256 = Sha2; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/kdf.cpp b/plugins/wasi_crypto/symmetric/kdf.cpp new file mode 100644 index 00000000..62f43c2f --- /dev/null +++ b/plugins/wasi_crypto/symmetric/kdf.cpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "symmetric/kdf.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/secret_vec.h" + +#include +#include + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +template constexpr uint32_t Hkdf::getKeySize() noexcept { + static_assert(ShaNid == NID_sha256 || ShaNid == NID_sha512); + + if constexpr (ShaNid == NID_sha256) + return 32; + if constexpr (ShaNid == NID_sha512) + return 64; +} + +template +constexpr const EVP_MD *Hkdf::getShaCtx() noexcept { + return EVP_get_digestbynid(ShaNid); +} + +template +WasiCryptoExpect::Expand::Key> +Hkdf::Expand::Key::generate(OptionalRef) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_FEATURE); +} + +template +WasiCryptoExpect::Expand::Key> +Hkdf::Expand::Key::import(Span Raw) noexcept { + ensureOrReturn(Raw.size() == getKeySize(), __WASI_CRYPTO_ERRNO_INVALID_KEY); + return SecretVec{Raw}; +} + +template +WasiCryptoExpect::Expand::State> +Hkdf::Expand::State::open(const Key &Key, + OptionalRef) noexcept { + return openStateImpl(Key.ref(), EVP_PKEY_HKDEF_MODE_EXPAND_ONLY); +} + +template +WasiCryptoExpect +Hkdf::Expand::State::absorb(Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_PKEY_CTX_add1_hkdf_info(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect +Hkdf::Expand::State::squeeze(Span Out) noexcept { + size_t KeyLen = Out.size(); + + { + std::scoped_lock Lock{Ctx->Mutex}; + ensureOrReturn(EVP_PKEY_derive(Ctx->RawCtx.get(), Out.data(), &KeyLen), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + } + + return {}; +} + +template +WasiCryptoExpect::Expand::State> +Hkdf::Expand::State::clone() const noexcept { + // not supported for a keygen operation. + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +template +WasiCryptoExpect::Extract::Key> +Hkdf::Extract::Key::generate(OptionalRef) noexcept { + return SecretVec::random(); +} + +template +WasiCryptoExpect::Extract::Key> +Hkdf::Extract::Key::import(Span Raw) noexcept { + return SecretVec{Raw}; +} + +template +WasiCryptoExpect::Extract::State> +Hkdf::Extract::State::open(const Key &Key, + OptionalRef) noexcept { + return openStateImpl(Key.ref(), EVP_PKEY_HKDEF_MODE_EXTRACT_ONLY); +} + +template +WasiCryptoExpect +Hkdf::Extract::State::absorb(Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + + Ctx->Salt.insert(Ctx->Salt.end(), Data.begin(), Data.end()); + return {}; +} + +template +WasiCryptoExpect::Expand::Key> +Hkdf::Extract::State::squeezeKey() noexcept { + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_PKEY_CTX_set1_hkdf_salt( + Ctx->RawCtx.get(), Ctx->Salt.data(), Ctx->Salt.size())); + } + size_t ActualOutSize = getKeySize(); + SecretVec Data(ActualOutSize); + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_PKEY_derive(Ctx->RawCtx.get(), Data.data(), &ActualOutSize)); + } + ensureOrReturn(ActualOutSize == getKeySize(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Data; +} + +template +WasiCryptoExpect::Extract::State> +Hkdf::Extract::State::clone() const noexcept { + // not supported for a keygen operation. + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +template +WasiCryptoExpect +Hkdf::openStateImpl(Span Key, int Mode) noexcept { + EvpPkeyCtxPtr Ctx{EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, nullptr)}; + opensslCheck(EVP_PKEY_derive_init(Ctx.get())); + opensslCheck(EVP_PKEY_CTX_set_hkdf_md(Ctx.get(), getShaCtx())); + opensslCheck(EVP_PKEY_CTX_hkdf_mode(Ctx.get(), Mode)); + ensureOrReturn(EVP_PKEY_CTX_set1_hkdf_key(Ctx.get(), Key.data(), Key.size()), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + + return Ctx; +} + +template class Hkdf::Extract; +template class Hkdf::Extract; +template class Hkdf::Expand; +template class Hkdf::Expand; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/kdf.h b/plugins/wasi_crypto/symmetric/kdf.h new file mode 100644 index 00000000..e17f77b2 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/kdf.h @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/kdf.h - Kdf related --------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of Key derivation and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/options.h" +#include "symmetric/tag.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Expand invalid operations. Every expand state should inherit from this +/// class. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#key-derivation-using-extract-and-expand +template class ExpandState { +public: + /// The current kdf does not support any options. + WasiCryptoExpect optionsGet(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + /// The current kdf does not support any options. + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encryptDetached(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decryptDetached(Span, Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeKey() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeTag() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect maxTagLen() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +/// Extract invalid operations. Every extract state should inherit from this +/// class. +template class ExtractState { +public: + /// The current kdf does not support any options. + WasiCryptoExpect optionsGet(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + /// The current kdf does not support any options. + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect squeeze(Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encryptDetached(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decryptDetached(Span, Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeTag() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect maxTagLen() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +template class Hkdf { +public: + class Expand { + public: + class Key { + public: + Key(SecretVec Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect import(Span Data) noexcept; + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + SecretVec exportData() const noexcept { return Data; } + + const SecretVec &ref() const noexcept { return Data; } + + private: + SecretVec Data; + }; + + class State : public ExpandState { + public: + static WasiCryptoExpect + open(const Key &Key, OptionalRef OptOption) noexcept; + + State(EvpPkeyCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + /// absorb info information. + WasiCryptoExpect absorb(Span Data) noexcept; + + /// derivation + WasiCryptoExpect squeeze(Span Out) noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + struct Inner { + Inner(EvpPkeyCtxPtr RawCtx) : RawCtx(std::move(RawCtx)) {} + EvpPkeyCtxPtr RawCtx; + std::mutex Mutex; + }; + std::shared_ptr Ctx; + }; + }; + + class Extract { + public: + class Key { + public: + Key(SecretVec Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect import(Span Data) noexcept; + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + SecretVec exportData() const noexcept { return Data; } + + const SecretVec &ref() const noexcept { return Data; } + + private: + SecretVec Data; + }; + + class State : public ExtractState { + public: + State(EvpPkeyCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + static WasiCryptoExpect + open(const Key &Key, OptionalRef OptOption) noexcept; + + /// Absorbs the salt information. + WasiCryptoExpect absorb(Span Data) noexcept; + + /// Returns the PRK, whose algorithm type is set to the EXPAND counterpart + /// of the EXTRACT operation. + WasiCryptoExpect squeezeKey() noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + struct Inner { + Inner(EvpPkeyCtxPtr RawCtx) : RawCtx(std::move(RawCtx)) {} + std::mutex Mutex; + std::vector Salt; + EvpPkeyCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + }; + +private: + constexpr static uint32_t getKeySize() noexcept; + + constexpr static const EVP_MD *getShaCtx() noexcept; + + static WasiCryptoExpect openStateImpl(Span Key, + int Mode) noexcept; +}; + +using HkdfSha256Extract = Hkdf::Extract; +using HkdfSha512Extract = Hkdf::Extract; +using HkdfSha256Expand = Hkdf::Expand; +using HkdfSha512Expand = Hkdf::Expand; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/key.cpp b/plugins/wasi_crypto/symmetric/key.cpp new file mode 100644 index 00000000..970d1ace --- /dev/null +++ b/plugins/wasi_crypto/symmetric/key.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "symmetric/key.h" +#include "utils/error.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +WasiCryptoExpect importKey(Algorithm Alg, + Span Data) noexcept { + return std::visit( + [Data](auto Factory) noexcept { + return decltype(Factory)::Key::import(Data).map( + [](auto &&Key) noexcept { + return KeyVariant{std::forward(Key)}; + }); + }, + Alg); +} + +WasiCryptoExpect +generateKey(Algorithm Alg, OptionalRef OptOptions) noexcept { + return std::visit( + [OptOptions](auto Factory) noexcept { + return decltype(Factory)::Key::generate(OptOptions) + .map([](auto &&Key) noexcept { + return KeyVariant{std::forward(Key)}; + }); + }, + Alg); +} + +SecretVec keyExportData(const KeyVariant &KeyVariant) noexcept { + return std::visit([](const auto &Key) noexcept { return Key.exportData(); }, + KeyVariant); +} + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/key.h b/plugins/wasi_crypto/symmetric/key.h new file mode 100644 index 00000000..d7e4aa6d --- /dev/null +++ b/plugins/wasi_crypto/symmetric/key.h @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/key.h - Symmetric Key class ===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Symmetric Key class definition. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/registered.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Object represents a key and an algorithm. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#symmetric-keys-1 +using KeyVariant = RegistedAlg::Key; + +WasiCryptoExpect importKey(Algorithm Alg, + Span Data) noexcept; + +WasiCryptoExpect +generateKey(Algorithm Alg, OptionalRef OptOptions) noexcept; + +/// Get the inner represent. +SecretVec keyExportData(const KeyVariant &Key) noexcept; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/mac.cpp b/plugins/wasi_crypto/symmetric/mac.cpp new file mode 100644 index 00000000..2b0bab77 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/mac.cpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "symmetric/mac.h" +#include "utils/secret_vec.h" + +#include + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +template constexpr size_t Hmac::getKeySize() noexcept { + static_assert(ShaNid == NID_sha256 || ShaNid == NID_sha512); + if constexpr (ShaNid == NID_sha256) { + return 32; + } + if constexpr (ShaNid == NID_sha512) { + return 64; + } +} + +template +WasiCryptoExpect::Key> +Hmac::Key::generate(OptionalRef) noexcept { + return SecretVec::random(); +} + +template +WasiCryptoExpect::Key> +Hmac::Key::import(Span Raw) noexcept { + return SecretVec{Raw}; +} + +template +WasiCryptoExpect::State> +Hmac::State::open(const Key &Key, OptionalRef) noexcept { + EvpPkeyPtr HmacKey{EVP_PKEY_new_raw_private_key( + EVP_PKEY_HMAC, nullptr, Key.ref().data(), Key.ref().size())}; + opensslCheck(HmacKey); + + EvpMdCtxPtr Ctx{EVP_MD_CTX_new()}; + + opensslCheck(EVP_DigestSignInit( + Ctx.get(), nullptr, EVP_get_digestbynid(ShaNid), nullptr, HmacKey.get())); + + return Ctx; +} + +template +WasiCryptoExpect +Hmac::State::absorb(Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + + opensslCheck( + EVP_DigestSignUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect Hmac::State::squeezeTag() noexcept { + size_t ActualOutSize = getKeySize(); + SecretVec Res(ActualOutSize); + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_DigestSignFinal(Ctx->RawCtx.get(), Res.data(), &ActualOutSize)); + } + + ensureOrReturn(ActualOutSize == getKeySize(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; +} + +template +WasiCryptoExpect::State> +Hmac::State::clone() const noexcept { + EvpMdCtxPtr CloneCtx{EVP_MD_CTX_new()}; + + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_MD_CTX_copy_ex(CloneCtx.get(), Ctx->RawCtx.get())); + } + + return CloneCtx; +} + +template class Hmac; +template class Hmac; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/mac.h b/plugins/wasi_crypto/symmetric/mac.h new file mode 100644 index 00000000..fad820b9 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/mac.h @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/mac.h - Mac related --------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of Mac and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/options.h" +#include "symmetric/tag.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Mac invalid operations. Every mac state should inherit from this class. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#message-authentication-codes +template class MacState { +public: + /// The current mac does not support any options. + WasiCryptoExpect optionsGet(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + /// The current mac does not support any options. + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect squeeze(Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encryptDetached(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decryptDetached(Span, Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeKey() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect maxTagLen() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +template class Hmac { +public: + class Key { + public: + Key(SecretVec Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect import(Span Data) noexcept; + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + SecretVec exportData() const noexcept { return Data; } + + const SecretVec &ref() const noexcept { return Data; } + + private: + SecretVec Data; + }; + + class State : public MacState { + public: + State(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + static WasiCryptoExpect + open(const Key &Key, OptionalRef OptOption) noexcept; + + /// Adds input data to the state. + /// + /// @param[in] Data the input data. + /// @return Nothing or WasiCrypto error. + WasiCryptoExpect absorb(Span Data) noexcept; + + /// Authenticates the input received up to the function call. + /// If finalization is required, the implementation MUST duplicate the + /// internal state and apply the finalization on the copy, leaving the state + /// unchanged from the guest perspective. + /// + /// @return Nothing or WasiCrypto error. + WasiCryptoExpect squeezeTag() noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr RawCtx) noexcept : RawCtx(std::move(RawCtx)) {} + EvpMdCtxPtr RawCtx; + std::mutex Mutex; + }; + std::shared_ptr Ctx; + }; + +private: + constexpr static size_t getKeySize() noexcept; +}; + +using HmacSha256 = Hmac; +using HmacSha512 = Hmac; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/module.cpp b/plugins/wasi_crypto/symmetric/module.cpp new file mode 100644 index 00000000..01540cf8 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/module.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "symmetric/module.h" +#include "symmetric/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoSymmetricModule::WasiCryptoSymmetricModule( + std::shared_ptr C) + : ModuleInstance("wasi_ephemeral_crypto_symmetric"), Ctx(C) { + using namespace WasiCrypto; + + addHostFunc("symmetric_key_generate", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_import", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_export", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_close", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_generate_managed", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_store_managed", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_replace_managed", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_id", std::make_unique(*Ctx)); + addHostFunc("symmetric_key_from_id", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_open", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_clone", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_options_get", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_options_get_u64", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_close", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_absorb", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_squeeze", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_squeeze_tag", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_squeeze_key", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_max_tag_len", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_encrypt", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_encrypt_detached", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_decrypt", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_decrypt_detached", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_ratchet", + std::make_unique(*Ctx)); + addHostFunc("symmetric_tag_len", std::make_unique(*Ctx)); + addHostFunc("symmetric_tag_pull", std::make_unique(*Ctx)); + addHostFunc("symmetric_tag_verify", + std::make_unique(*Ctx)); + addHostFunc("symmetric_tag_close", + std::make_unique(*Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/module.h b/plugins/wasi_crypto/symmetric/module.h new file mode 100644 index 00000000..14052359 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/module.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/module.h - Module ----------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto symmetric module +/// class. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiCryptoSymmetricModule : public Runtime::Instance::ModuleInstance { +public: + WasiCryptoSymmetricModule(std::shared_ptr); + + WasiCrypto::Context &getContext() { return *Ctx.get(); } + +private: + std::shared_ptr Ctx; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/options.cpp b/plugins/wasi_crypto/symmetric/options.cpp new file mode 100644 index 00000000..ae2b3c5f --- /dev/null +++ b/plugins/wasi_crypto/symmetric/options.cpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "symmetric/options.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { +using namespace std::literals; + +namespace { +constexpr std::array ValidNames{"context"sv, "salt"sv, + "nonce"sv}; + +std::string toLower(std::string_view Name) noexcept { + std::string Ret{Name}; + std::transform(Ret.begin(), Ret.end(), Ret.begin(), + [](char C) { return static_cast(std::tolower(C)); }); + return Ret; +} + +bool isValidName(std::string_view Name) noexcept { + return std::find(ValidNames.begin(), ValidNames.end(), Name) != + ValidNames.end(); +} + +constexpr std::array ValidU64Names{ + "memory_limit"sv, "ops_limit"sv, "parallelism"sv}; + +bool isValidU64Name(std::string_view Name) noexcept { + return std::find(ValidU64Names.begin(), ValidU64Names.end(), Name) != + ValidU64Names.end(); +} + +} // namespace + +WasiCryptoExpect Options::set(std::string_view Name, + Span Value) noexcept { + std::string ActuallyName = toLower(Name); + + ensureOrReturn(isValidName(ActuallyName), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + { + std::unique_lock Lock{Inner->Mutex}; + Inner->ValueMap.insert_or_assign(ActuallyName, + std::vector(Value.begin(), Value.end())); + } + return {}; +} + +WasiCryptoExpect Options::setU64(std::string_view Name, + uint64_t Value) noexcept { + std::string ActuallyName = toLower(Name); + ensureOrReturn(isValidU64Name(ActuallyName), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + { + std::unique_lock Lock{Inner->Mutex}; + Inner->U64ValueMap.insert_or_assign(ActuallyName, Value); + } + return {}; +} + +WasiCryptoExpect Options::setGuestBuffer(std::string_view Name, + Span Buffer) noexcept { + std::string ActuallyName = toLower(Name); + ensureOrReturn(ActuallyName == "buffer"sv, + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + { + std::unique_lock Lock{Inner->Mutex}; + Inner->GuestBuffer = Buffer; + } + return {}; +} + +WasiCryptoExpect Options::get(std::string_view Name, + Span Value) const noexcept { + std::string ActuallyName = toLower(Name); + ensureOrReturn(isValidName(ActuallyName), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + { + std::shared_lock Lock{Inner->Mutex}; + if (auto It = Inner->ValueMap.find(ActuallyName); + It != Inner->ValueMap.end()) { + ensureOrReturn(It->second.size() <= Value.size(), + __WASI_CRYPTO_ERRNO_OVERFLOW); + std::copy(It->second.begin(), It->second.end(), Value.begin()); + return It->second.size(); + } + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_OPTION_NOT_SET); +} + +WasiCryptoExpect +Options::getU64(std::string_view Name) const noexcept { + std::string ActuallyName = toLower(Name); + ensureOrReturn(isValidU64Name(ActuallyName), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + { + std::shared_lock Lock{Inner->Mutex}; + if (auto It = Inner->U64ValueMap.find(ActuallyName); + It != Inner->U64ValueMap.end()) { + return It->second; + } + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_OPTION_NOT_SET); +} + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/options.h b/plugins/wasi_crypto/symmetric/options.h new file mode 100644 index 00000000..7e7d1bd1 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/options.h @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/options.h - Options --------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Symmetric Options class definition. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" + +#include "common/span.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Options for symmetric state and key. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#options-1 +class Options { +public: + WasiCryptoExpect set(std::string_view Name, + Span Value) noexcept; + + WasiCryptoExpect setU64(std::string_view Name, uint64_t Value) noexcept; + + WasiCryptoExpect setGuestBuffer(std::string_view Name, + Span Buffer) noexcept; + + WasiCryptoExpect get(std::string_view Name, + Span Value) const noexcept; + + WasiCryptoExpect getU64(std::string_view Name) const noexcept; + +private: + struct DataType { + std::map> ValueMap; + std::map U64ValueMap; + std::optional> GuestBuffer; + mutable std::shared_mutex Mutex; + }; + + std::shared_ptr Inner = std::make_shared(); +}; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/registered.h b/plugins/wasi_crypto/symmetric/registered.h new file mode 100644 index 00000000..2bf5a622 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/registered.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/registered.h - Registered --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the register symmetric algorithm definitions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/aeads.h" +#include "symmetric/hash.h" +#include "symmetric/kdf.h" +#include "symmetric/mac.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Registered algorithm +template struct Registered { + using Key = std::variant; + using State = std::variant; + using Variant = std::variant; +}; + +using RegistedAlg = + Registered; + +using Algorithm = RegistedAlg::Variant; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/state.cpp b/plugins/wasi_crypto/symmetric/state.cpp new file mode 100644 index 00000000..ce358f9d --- /dev/null +++ b/plugins/wasi_crypto/symmetric/state.cpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "symmetric/state.h" +#include "utils/error.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { +namespace { +WasiCryptoExpect checkedAdd(size_t A, size_t B) { + size_t Res; + ensureOrReturn(!__builtin_add_overflow(A, B, &Res), + __WASI_CRYPTO_ERRNO_OVERFLOW); + return Res; +} + +/// Correspond signatures: +template struct StateOpenTrait; + +/// WasiCryptoExpect open(const KeyType &, OptionalRef); +template +struct StateOpenTrait (*)( + const KeyType &, OptionalRef) noexcept> { + static inline constexpr bool NeedKey = true; + using Key = KeyType; +}; + +/// WasiCryptoExpect open(OptionalRef); +template +struct StateOpenTrait (*)( + OptionalRef) noexcept> { + static inline constexpr bool NeedKey = false; +}; + +template +using GetStateOpenTrait = StateOpenTrait; +} // namespace + +WasiCryptoExpect +openState(Algorithm Alg, OptionalRef OptKeyVariant, + OptionalRef OptOptions) noexcept { + return std::visit( + [=](auto Factory) noexcept -> WasiCryptoExpect { + using StateOpen = GetStateOpenTrait; + if constexpr (StateOpen::NeedKey) { + using RequiredKeyType = typename StateOpen::Key; + // Need key. Not have key, fail. + if (unlikely(!OptKeyVariant)) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_REQUIRED); + } + return std::visit( + [OptOptions](const auto &Key) -> WasiCryptoExpect { + using InKeyType = std::decay_t; + if constexpr (!std::is_same_v) { + // Key types do not match. + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_KEY); + } else { + // Key type fitted. + return decltype(Factory)::State::open(Key, OptOptions) + .map([](auto &&State) noexcept { + return StateVariant{ + std::forward(State)}; + }); + } + }, + *OptKeyVariant); + + } else { + // Not need key. Have key, fail. + if (unlikely(!!OptKeyVariant)) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); + } + return decltype(Factory)::State::open(OptOptions) + .map([](auto &&State) noexcept { + return StateVariant{std::forward(State)}; + }); + } + }, + Alg); +} + +WasiCryptoExpect stateOptionsGet(const StateVariant &StateVariant, + std::string_view Name, + Span Value) noexcept { + return std::visit( + [=](const auto &State) noexcept { return State.optionsGet(Name, Value); }, + StateVariant); +} + +WasiCryptoExpect stateOptionsGetU64(const StateVariant &StateVariant, + std::string_view Name) noexcept { + return std::visit( + [Name](const auto &State) noexcept { return State.optionsGetU64(Name); }, + StateVariant); +} + +WasiCryptoExpect stateAbsorb(StateVariant &StateVariant, + Span Data) noexcept { + return std::visit([Data](auto &State) noexcept { return State.absorb(Data); }, + StateVariant); +} + +WasiCryptoExpect stateSqueeze(StateVariant &StateVariant, + Span Out) noexcept { + return std::visit([Out](auto &State) noexcept { return State.squeeze(Out); }, + StateVariant); +} + +WasiCryptoExpect stateSqueezeTag(StateVariant &StateVariant) noexcept { + return std::visit([](auto &State) noexcept { return State.squeezeTag(); }, + StateVariant); +} + +namespace { +template struct GetSqueezeKeyTypeTrait; +template +struct GetSqueezeKeyTypeTrait ( + StateType::*)() noexcept> { + using Key = KeyType; +}; +template +using GetSqueezeKeyType = + typename GetSqueezeKeyTypeTrait::Key; + +} // namespace + +WasiCryptoExpect stateSqueezeKey(StateVariant &StateVariant, + Algorithm KeyAlg) noexcept { + return std::visit( + [](auto &State, auto Alg) noexcept -> WasiCryptoExpect { + if constexpr (std::is_same_v< + GetSqueezeKeyType>, + typename decltype(Alg)::Key>) { + return State.squeezeKey().map([](auto &&Key) { + return KeyVariant{std::forward(Key)}; + }); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); + } + }, + StateVariant, KeyAlg); +} + +WasiCryptoExpect +stateMaxTagLen(const StateVariant &StateVariant) noexcept { + return std::visit( + [](const auto &State) noexcept { return State.maxTagLen(); }, + StateVariant); +} + +WasiCryptoExpect stateEncrypt(StateVariant &StateVariant, + Span Out, + Span Data) noexcept { + return std::visit( + [=](auto &State) noexcept -> WasiCryptoExpect { + return State.maxTagLen() + .and_then([DataSize = Data.size()](size_t TagLen) noexcept { + return checkedAdd(DataSize, TagLen); + }) + .and_then([Out, Data, &State](size_t ActualDataLen) noexcept + -> WasiCryptoExpect { + ensureOrReturn(Out.size() == ActualDataLen, + __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + return State.encrypt(Out, Data); + }); + }, + StateVariant); +} + +WasiCryptoExpect stateEncryptDetached(StateVariant &StateVariant, + Span Out, + Span Data) noexcept { + ensureOrReturn(Data.size() == Out.size(), __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + return std::visit( + [=](auto &State) noexcept { return State.encryptDetached(Out, Data); }, + StateVariant); +} + +WasiCryptoExpect stateDecrypt(StateVariant &StateVariant, + Span Out, + Span Data) noexcept { + return std::visit( + [=](auto &State) noexcept -> WasiCryptoExpect { + return State.maxTagLen() + .and_then([OutSize = Out.size()](size_t TagLen) noexcept { + return checkedAdd(OutSize, TagLen); + }) + .and_then([Out, Data, &State](size_t ActualOutLen) noexcept + -> WasiCryptoExpect { + ensureOrReturn(Data.size() == ActualOutLen, + __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + return State.decrypt(Out, Data); + }); + }, + StateVariant); +} + +WasiCryptoExpect +stateDecryptDetached(StateVariant &StateVariant, Span Out, + Span Data, + Span RawTag) noexcept { + ensureOrReturn(Data.size() == Out.size(), __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + return std::visit( + [=](auto &State) noexcept { + return State.decryptDetached(Out, Data, RawTag); + }, + StateVariant); +} + +WasiCryptoExpect stateRatchet(StateVariant &StateVariant) noexcept { + return std::visit([](auto &State) noexcept { return State.ratchet(); }, + StateVariant); +} + +WasiCryptoExpect +stateClone(const StateVariant &ClonedStateVariant) noexcept { + return std::visit( + [](const auto &ClonedState) noexcept -> WasiCryptoExpect { + return ClonedState.clone().map([](auto &&NewState) noexcept { + return StateVariant{std::forward(NewState)}; + }); + }, + ClonedStateVariant); +} + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/state.h b/plugins/wasi_crypto/symmetric/state.h new file mode 100644 index 00000000..f681f7ce --- /dev/null +++ b/plugins/wasi_crypto/symmetric/state.h @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/state.h - Symmetric State --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the symmetric state related classes, and provides a +/// unified interface which can be used to implement the algorithm operations. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/key.h" +#include "symmetric/registered.h" +#include "symmetric/tag.h" +#include "utils/error.h" + +#include "common/span.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// State created from a key, and performs symmetric operations using the +/// underlying algorithms. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#state +using StateVariant = RegistedAlg::State; + +WasiCryptoExpect +openState(Algorithm Alg, OptionalRef OptKeyVariant, + OptionalRef OptOptions) noexcept; + +WasiCryptoExpect stateOptionsGet(const StateVariant &StateVariant, + std::string_view Name, + Span Value) noexcept; + +WasiCryptoExpect stateOptionsGetU64(const StateVariant &StateVariant, + std::string_view Name) noexcept; + +/// Absorb data into the state. +WasiCryptoExpect stateAbsorb(StateVariant &StateVariant, + Span Data) noexcept; + +/// Squeeze bytes from the state. +WasiCryptoExpect stateSqueeze(StateVariant &StateVariant, + Span Out) noexcept; + +/// Compute and return a tag for all the data injected into the state so far. +WasiCryptoExpect stateSqueezeTag(StateVariant &StateVariant) noexcept; + +/// Use the current state to produce a key for a target algorithm. +WasiCryptoExpect stateSqueezeKey(StateVariant &StateVariant, + Algorithm KeyAlg) noexcept; + +/// Encrypt data with an attached tag. +WasiCryptoExpect stateEncrypt(StateVariant &StateVariant, + Span Out, + Span Data) noexcept; + +/// Encrypt data and return the ciphertext and the authentication tag +/// separately. +WasiCryptoExpect stateEncryptDetached(StateVariant &StateVariant, + Span Out, + Span Data) noexcept; + +/// Decrypt a ciphertext with an attached tag. +WasiCryptoExpect stateDecrypt(StateVariant &StateVariant, + Span Out, + Span Data) noexcept; + +/// Verify an authentication tag and decrypt the corresponding ciphertext if +/// the verification passes. +WasiCryptoExpect +stateDecryptDetached(StateVariant &StateVariant, Span Out, + Span Data, + Span RawTag) noexcept; + +/// Returns the length required to encode the authentication tag and optional +/// padding bytes. +WasiCryptoExpect +stateMaxTagLen(const StateVariant &StateVariant) noexcept; + +/// Make the state impossible to recover the previous state. +WasiCryptoExpect stateRatchet(StateVariant &StateVariant) noexcept; + +/// Clone the state. +WasiCryptoExpect +stateClone(const StateVariant &ClonedStateVariant) noexcept; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/tag.cpp b/plugins/wasi_crypto/symmetric/tag.cpp new file mode 100644 index 00000000..3998aee2 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/tag.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "symmetric/tag.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +WasiCryptoExpect Tag::verify(Span RawTag) const noexcept { + ensureOrReturn(!CRYPTO_memcmp(RawTag.data(), Data.data(), RawTag.size()), + __WASI_CRYPTO_ERRNO_INVALID_TAG); + + return {}; +} + +WasiCryptoExpect Tag::pull(Span Raw) const noexcept { + if (Raw.size() > Data.size()) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_LENGTH); + } + if (Raw.size() < Data.size()) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_OVERFLOW); + } + + std::copy(Data.begin(), Data.end(), Raw.begin()); + return Data.size(); +} + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/tag.h b/plugins/wasi_crypto/symmetric/tag.h new file mode 100644 index 00000000..6ecc26a4 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/tag.h @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/tag.h - Symmetric Tag class ===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Symmetric Tag definition. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "utils/error.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Authentication tag that can be verified without channels using the provided +/// APIs. Very small and no streaming. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#authentication-tags +class Tag { +public: + Tag(Tag &&Data) noexcept = default; + Tag &operator=(Tag &&Data) noexcept = default; + Tag(const Tag &Data) noexcept = delete; + Tag &operator=(const Tag &Data) noexcept = delete; + + Tag(SecretVec &&Data) noexcept : Data(std::move(Data)) {} + + size_t len() const noexcept { return Data.size(); } + + /// The function MUST return `__WASI_CRYPTO_ERRNO_INVALID_TAG` if the + /// tags do not match. + WasiCryptoExpect verify(Span RawTag) const noexcept; + + WasiCryptoExpect pull(Span Raw) const noexcept; + +private: + SecretVec Data; +}; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/error.h b/plugins/wasi_crypto/utils/error.h new file mode 100644 index 00000000..24018755 --- /dev/null +++ b/plugins/wasi_crypto/utils/error.h @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/utils/error.h - Error definition -----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the wasi-crypto error handling related functions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "common/errcode.h" +#include "common/expected.h" +#include "wasi_crypto/api.hpp" + +#include +#include + +/// Ensure the Expr is true or return ErrorCode. +#define ensureOrReturn(Expr, ErrorCode) \ + do { \ + if (!(Expr)) { \ + return WasiCryptoUnexpect((ErrorCode)); \ + } \ + } while (0) + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +/// Type aliasing for Expected. +template +using WasiCryptoExpect = Expected; + +/// Helper function for Unexpected. +constexpr auto WasiCryptoUnexpect(__wasi_crypto_errno_e_t Val) noexcept { + return Unexpected<__wasi_crypto_errno_e_t>(Val); +} +template +constexpr auto WasiCryptoUnexpect(const WasiCryptoExpect &Val) noexcept { + return Unexpected<__wasi_crypto_errno_e_t>(Val.error()); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/evp_wrapper.cpp b/plugins/wasi_crypto/utils/evp_wrapper.cpp new file mode 100644 index 00000000..84842220 --- /dev/null +++ b/plugins/wasi_crypto/utils/evp_wrapper.cpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "utils/evp_wrapper.h" +#include "utils/error.h" + +#include +#include + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +namespace { + +BioPtr createBioFromSpan(Span Encoded) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + if (!Bio) { + return nullptr; + } + + size_t Size; + if (!BIO_write_ex(Bio.get(), Encoded.data(), Encoded.size(), &Size) || + Size != Encoded.size()) { + return nullptr; + } + return Bio; +} + +template +WasiCryptoExpect writeKeyToBio(EVP_PKEY *Key, WriteFunc &&Func) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + opensslCheck(Func(Bio.get(), Key)); + + BUF_MEM *Mem = nullptr; + opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); + + T Ret(Mem->length); + size_t Size; + if (BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size) && + Size == Ret.size()) { + return Ret; + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); +} + +} // namespace + +EVP_PKEY *pemReadPUBKEY(Span Encoded) { + auto Bio = createBioFromSpan(Encoded); + if (!Bio) { + return nullptr; + } + return PEM_read_bio_PUBKEY(Bio.get(), nullptr, nullptr, nullptr); +} + +WasiCryptoExpect> pemWritePUBKEY(EVP_PKEY *Key) { + return writeKeyToBio>( + Key, [](BIO *B, EVP_PKEY *K) { return PEM_write_bio_PUBKEY(B, K); }); +} + +EVP_PKEY *pemReadPrivateKey(Span Encoded) { + auto Bio = createBioFromSpan(Encoded); + if (!Bio) { + return nullptr; + } + return PEM_read_bio_PrivateKey(Bio.get(), nullptr, nullptr, nullptr); +} + +WasiCryptoExpect pemWritePrivateKey(EVP_PKEY *Key) { + return writeKeyToBio(Key, [](BIO *B, EVP_PKEY *K) { + return PEM_write_bio_PrivateKey(B, K, nullptr, nullptr, 0, nullptr, + nullptr); + }); +} + +EVP_PKEY *d2iPUBKEY(Span Encoded) { + auto Bio = createBioFromSpan(Encoded); + if (!Bio) { + return nullptr; + } + return d2i_PUBKEY_bio(Bio.get(), nullptr); +} + +WasiCryptoExpect> i2dPUBKEY(EVP_PKEY *Key) { + return writeKeyToBio>(Key, i2d_PUBKEY_bio); +} + +EVP_PKEY *d2iPrivateKey(Span Encoded) { + auto Bio = createBioFromSpan(Encoded); + if (!Bio) { + return nullptr; + } + return d2i_PrivateKey_bio(Bio.get(), nullptr); +} + +WasiCryptoExpect i2dPrivateKey(EVP_PKEY *Key) { + return writeKeyToBio(Key, i2d_PrivateKey_bio); +} + +ECDSA_SIG *d2iEcdsaSig(Span Encoded) { + if (Encoded.size() > static_cast(std::numeric_limits::max())) { + return nullptr; + } + auto *Data = Encoded.data(); + return d2i_ECDSA_SIG(nullptr, &Data, static_cast(Encoded.size())); +} + +WasiCryptoExpect> i2dEcdsaSig(ECDSA_SIG *Sig) { + int SigSize = i2d_ECDSA_SIG(Sig, nullptr); + ensureOrReturn(SigSize >= 0, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + std::vector Res(static_cast(SigSize)); + + auto *Data = Res.data(); + auto NewSize = i2d_ECDSA_SIG(Sig, &Data); + ensureOrReturn(NewSize == SigSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; +} + +ECDSA_SIG *o2iEcdsaSig(Span Encoded) { + if (Encoded.size() > static_cast(std::numeric_limits::max())) { + return nullptr; + } + int EncodedSize = static_cast(Encoded.size()); + BnPtr R{BN_bin2bn(Encoded.data(), EncodedSize / 2, nullptr)}; + BnPtr S{ + BN_bin2bn(Encoded.data() + EncodedSize / 2, EncodedSize / 2, nullptr)}; + EcdsaSigPtr Sig{ECDSA_SIG_new()}; + if (!ECDSA_SIG_set0(Sig.get(), R.release(), S.release())) { + return nullptr; + } + + return Sig.release(); +} + +WasiCryptoExpect> i2oEcdsaSig(ECDSA_SIG *Sig) { + auto *R = ECDSA_SIG_get0_r(Sig); + auto *S = ECDSA_SIG_get0_s(Sig); + auto RSize = static_cast(BN_num_bytes(R)); + auto SSize = static_cast(BN_num_bytes(S)); + std::vector Res(RSize + SSize); + opensslCheck(BN_bn2bin(R, Res.data())); + opensslCheck(BN_bn2bin(S, Res.data() + RSize)); + + return Res; +} + +SharedEvpPkey::~SharedEvpPkey() noexcept { + if (Pkey != nullptr) { + EVP_PKEY_free(Pkey); + Pkey = nullptr; + } +} + +SharedEvpPkey::SharedEvpPkey(const SharedEvpPkey &Rhs) noexcept + : Pkey(Rhs.Pkey) { + if (Rhs.Pkey != nullptr) { + EVP_PKEY_up_ref(Pkey); + } +} + +SharedEvpPkey::SharedEvpPkey(SharedEvpPkey &&Rhs) noexcept : Pkey(Rhs.Pkey) { + Rhs.Pkey = nullptr; +} + +EVP_PKEY *SharedEvpPkey::get() const noexcept { return Pkey; } + +SharedEvpPkey::operator bool() const noexcept { return Pkey != nullptr; } + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/evp_wrapper.h b/plugins/wasi_crypto/utils/evp_wrapper.h new file mode 100644 index 00000000..9827e6f8 --- /dev/null +++ b/plugins/wasi_crypto/utils/evp_wrapper.h @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/utils/evp_wrapper.h - Evp Wrapper ----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definitions of OpenSSL EVP-related functions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" +#include "utils/secret_vec.h" + +#include "common/span.h" +#include "common/spdlog.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +/// Helper alias for OpenSSL relative struct. +template using Deleter = std::integral_constant; + +template +using OpenSSLUniquePtr = std::unique_ptr>; + +using EvpMdCtxPtr = OpenSSLUniquePtr; +using EvpPkeyCtxPtr = OpenSSLUniquePtr; +using EvpCipherCtxPtr = OpenSSLUniquePtr; +using EvpPkeyPtr = OpenSSLUniquePtr; +using BioPtr = OpenSSLUniquePtr; +using EcKeyPtr = OpenSSLUniquePtr; +using BnPtr = OpenSSLUniquePtr; +using EcPointPtr = OpenSSLUniquePtr; +using EcdsaSigPtr = OpenSSLUniquePtr; +using RsaPtr = OpenSSLUniquePtr; + +/// OpenSSL functions always return 1 for success and 0/NULL for failure. This +/// is used to reduce repeated checks. +#ifdef NDEBUG +#define opensslCheck(Cond) \ + do { \ + if (!(Cond)) { \ + using namespace std::literals; \ + ERR_print_errors_cb( \ + [](const char *ErrStr, size_t ErrLen, void *) { \ + spdlog::error("{}"sv, std::string_view(ErrStr, ErrLen)); \ + return 1; \ + }, \ + nullptr); \ + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); \ + } \ + } while (false) +#else +#define opensslCheck(Cond) \ + do { \ + if (!(Cond)) { \ + using namespace std::literals; \ + ERR_print_errors_cb( \ + [](const char *ErrStr, size_t ErrLen, void *) { \ + spdlog::error("{}"sv, std::string_view(ErrStr, ErrLen)); \ + return 1; \ + }, \ + nullptr); \ + OPENSSL_die("assertion failed: " #Cond, __FILE__, __LINE__); \ + } \ + } while (false) +#endif + +/// OpenSSL encoding parse api is too confusing, simplify them. +/// For example, `PEM_read_bio_PUBKEY` is equal to `pemReadPUBKEY`. + +// ------------------------------------------------------------------------- // +EVP_PKEY *pemReadPUBKEY(Span Encoded); + +WasiCryptoExpect> pemWritePUBKEY(EVP_PKEY *Key); + +EVP_PKEY *pemReadPrivateKey(Span Encoded); + +WasiCryptoExpect pemWritePrivateKey(EVP_PKEY *Key); + +EVP_PKEY *d2iPUBKEY(Span Encoded); + +WasiCryptoExpect> i2dPUBKEY(EVP_PKEY *Key); + +EVP_PKEY *d2iPrivateKey(Span Encoded); + +WasiCryptoExpect i2dPrivateKey(EVP_PKEY *Key); + +ECDSA_SIG *d2iEcdsaSig(Span Encoded); + +WasiCryptoExpect> i2dEcdsaSig(ECDSA_SIG *Sig); + +// ------------------------------------------------------------------------- // +// Transform raw represent ecdsa ( r | s) to ECDSA_SIG. Need to check `nullptr`. +ECDSA_SIG *o2iEcdsaSig(Span Encoded); + +// Transform ECDSA_SIG to raw representation (r | s). +WasiCryptoExpect> i2oEcdsaSig(ECDSA_SIG *Sig); + +// This is a wrapper for EVP_PKEY. Since EVP_PKEY internally uses locks to +// guarantee thread-safe `EVP_PKEY_up_ref` (you will find them in +// crypto/evp/p_lib.c in OpenSSL v1.1.1), using shared_ptr for `EVP_PKEY` is +// wasteful. It only provides limited functions for correct use. +class SharedEvpPkey { +public: + SharedEvpPkey(EvpPkeyPtr Pkey) noexcept : Pkey(Pkey.release()) {} + ~SharedEvpPkey() noexcept; + + SharedEvpPkey(const SharedEvpPkey &Rhs) noexcept; + SharedEvpPkey(SharedEvpPkey &&Rhs) noexcept; + // Assigning to an existing SharedEvpPkey is not thread-safe, so delete the + // assignment operators. + SharedEvpPkey &operator=(const SharedEvpPkey &Rhs) noexcept = delete; + SharedEvpPkey &operator=(SharedEvpPkey &&Rhs) noexcept = delete; + + EVP_PKEY *get() const noexcept; + + explicit operator bool() const noexcept; + +private: + EVP_PKEY *Pkey; +}; + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/handles_manager.h b/plugins/wasi_crypto/utils/handles_manager.h new file mode 100644 index 00000000..f7f072ea --- /dev/null +++ b/plugins/wasi_crypto/utils/handles_manager.h @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/utils/handles_manager.h --------------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the class definitions of the WasiCrypto HandlesManager. +/// It controls the handle and the inner states. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace detail { + +/// The Handles Manager base class. +/// +/// @tparam HandleType This is the handle type. It must be 32 bits wide. +/// @tparam ManagerType The managed content type. +/// +/// HandlesManager uses a handle as the index to represent the managed contents. +/// +/// Referenced from: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/implementations/hostcalls/rust/src/handles.rs +template class BaseHandlesManager { +public: + BaseHandlesManager(const BaseHandlesManager &) noexcept = delete; + BaseHandlesManager &operator=(const BaseHandlesManager &) noexcept = delete; + BaseHandlesManager(BaseHandlesManager &&) noexcept = delete; + BaseHandlesManager &operator=(BaseHandlesManager &&) noexcept = delete; + + /// @param TypeID A unique number + explicit BaseHandlesManager(uint8_t TypeID) noexcept + : LastHandle{TypeID, 0} {} + + WasiCryptoExpect close(HandleType Handle) noexcept { + std::unique_lock Lock{Mutex}; + + if (!Map.erase(HandleWrapper(Handle))) + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_CLOSED); + + return {}; + } + + /// Construct a new manager. + template + WasiCryptoExpect registerManager(Args &&...Manager) noexcept { + std::unique_lock Lock{Mutex}; + + // Find an available handle and emplace it. + // Assume that LastHandle is 0x01000000, NextHandle is 0x01000001. + auto NextHandle = LastHandle.nextHandle(); + while (true) { + // Try to emplace NextHandle. + if (Map.try_emplace(NextHandle, std::forward(Manager)...).second) { + // If this succeeds, the emplacement indicates that NextHandle does not + // exist in the managed content. Update the last handle and return it. + LastHandle = NextHandle; + return LastHandle.Handle; + } + // Otherwise, the NextHandle map already contains content. Advance + // NextHandle and loop. + NextHandle = NextHandle.nextHandle(); + + // If, after looping many times (2^24 - 1), we get 0x01000000 again, the + // hash map is full. + if (NextHandle == LastHandle) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_TOO_MANY_HANDLES); + } + } + } + +protected: + /// The handle internal representation: [-TypeID-|------CurrentNumber------] + union HandleWrapper { + static_assert(sizeof(HandleType) == 4, "HandleType must be 4 byte"); + HandleWrapper(uint8_t TypeID, uint32_t CurrentNumber) noexcept + : TypeID(TypeID), CurrentNumber(CurrentNumber) {} + explicit HandleWrapper(HandleType Handle) : Handle(Handle) {} + + HandleWrapper nextHandle() noexcept { + return {TypeID, static_cast(CurrentNumber + 1)}; + } + + struct Hash { + size_t operator()(const HandleWrapper &Wrapper) const noexcept { + return static_cast(Wrapper.Handle); + } + }; + + bool operator==(const HandleWrapper &Wrapper) const noexcept { + return Wrapper.Handle == this->Handle; + } + + struct { + uint8_t TypeID : 8; + uint32_t CurrentNumber : 24; + }; + HandleType Handle; + }; + + std::shared_mutex Mutex; + HandleWrapper LastHandle; + std::unordered_map + Map; +}; + +template struct IsVariantMember; +template +struct IsVariantMember> + : public std::disjunction...> {}; + +} // namespace detail + +/// ManagerType need reference count. +template , bool> = + false> +class RcHandlesManager + : public detail::BaseHandlesManager { + using HandleWrapper = + typename detail::BaseHandlesManager::HandleWrapper; + +public: + using detail::BaseHandlesManager::BaseHandlesManager; + + /// Get the return copy. + WasiCryptoExpect get(HandleType Handle) noexcept { + std::shared_lock Lock{this->Mutex}; + + auto HandleValue = this->Map.find(HandleWrapper(Handle)); + if (HandleValue == this->Map.end()) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + return HandleValue->second; + } + + /// Get as different variant type. + template + WasiCryptoExpect getAs(HandleType Handle) noexcept { + std::shared_lock Lock{this->Mutex}; + + auto HandleValue = this->Map.find(HandleWrapper(Handle)); + if (HandleValue == this->Map.end()) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + return std::visit( + [](auto &&Value) noexcept -> WasiCryptoExpect { + using T = std::decay_t; + if constexpr (detail::IsVariantMember::value) { + + return Value; + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + }, + HandleValue->second); + } +}; + +/// ManagerType just use reference. +template +class RefHandlesManager + : public detail::BaseHandlesManager { + using HandleWrapper = + typename detail::BaseHandlesManager::HandleWrapper; + +public: + using detail::BaseHandlesManager::BaseHandlesManager; + + /// Get the return reference. + WasiCryptoExpect> + get(HandleType Handle) noexcept { + std::shared_lock Lock{this->Mutex}; + + auto HandleValue = this->Map.find(HandleWrapper(Handle)); + if (HandleValue == this->Map.end()) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + return HandleValue->second; + } +}; + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/hostfunction.cpp b/plugins/wasi_crypto/utils/hostfunction.cpp new file mode 100644 index 00000000..81c078a9 --- /dev/null +++ b/plugins/wasi_crypto/utils/hostfunction.cpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "utils/hostfunction.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +using namespace std::literals; + +namespace { +std::string toUpper(std::string_view Name) noexcept { + std::string Ret{Name}; + std::transform(Ret.begin(), Ret.end(), Ret.begin(), + [](char C) { return std::toupper(C); }); + return Ret; +} +} // namespace + +WasiCryptoExpect +tryFrom(__wasi_algorithm_type_e_t AlgType, + std::string_view RawAlgStr) noexcept { + std::string AlgStr = toUpper(RawAlgStr); + // Delegate to sig and kx. + switch (AlgType) { + case __WASI_ALGORITHM_TYPE_SIGNATURES: { + return tryFrom(AlgStr).map([](auto Alg) noexcept { + return std::visit( + [](auto Factory) noexcept -> AsymmetricCommon::Algorithm { + return Factory; + }, + Alg); + }); + } + case __WASI_ALGORITHM_TYPE_KEY_EXCHANGE: { + return tryFrom(AlgStr).map([](auto Alg) noexcept { + return std::visit( + [](auto Factory) noexcept -> AsymmetricCommon::Algorithm { + return Factory; + }, + Alg); + }); + } + case __WASI_ALGORITHM_TYPE_SYMMETRIC: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + default: + assumingUnreachable(); + } +} + +template <> +WasiCryptoExpect tryFrom(std::string_view RawAlgStr) noexcept { + using namespace Kx; + std::string AlgStr = toUpper(RawAlgStr); + if (AlgStr == "X25519"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "P256-SHA256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "P384-SHA384"sv) { + return Algorithm{std::in_place_type}; + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); +} + +template <> +WasiCryptoExpect +tryFrom(std::string_view RawAlgStr) noexcept { + using namespace Symmetric; + std::string AlgStr = toUpper(RawAlgStr); + if (AlgStr == "SHA-256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "SHA-512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "SHA-512/256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HMAC/SHA-256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HMAC/SHA-512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HKDF-EXPAND/SHA-256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HKDF-EXTRACT/SHA-256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HKDF-EXPAND/SHA-512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HKDF-EXTRACT/SHA-512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "AES-128-GCM"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "AES-256-GCM"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "CHACHA20-POLY1305"sv) { + return Algorithm{std::in_place_type}; + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); +} + +template <> +WasiCryptoExpect +tryFrom(std::string_view RawAlgStr) noexcept { + using namespace Signatures; + std::string AlgStr = toUpper(RawAlgStr); + if (AlgStr == "ECDSA_P256_SHA256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "ECDSA_K256_SHA256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "ECDSA_P384_SHA384"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "ED25519"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_2048_SHA256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_2048_SHA384"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_2048_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_3072_SHA384"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_3072_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_4096_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_2048_SHA256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_2048_SHA384"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_2048_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_3072_SHA384"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_3072_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_4096_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/hostfunction.h b/plugins/wasi_crypto/utils/hostfunction.h new file mode 100644 index 00000000..460130eb --- /dev/null +++ b/plugins/wasi_crypto/utils/hostfunction.h @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/hostfunc.h - HostFunction class ------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the HostFunction classes and some helper functions +/// interact with wasi. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" +#include "symmetric/registered.h" +#include "utils/error.h" + +#include "runtime/callingframe.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +/// The wasi-crypto's base HostFunction class, every wasi-crypto HostFunction +/// should inherit from this class. +template class HostFunction : public Runtime::HostFunction { +public: + HostFunction(Context &Ctx) : Runtime::HostFunction(0), Ctx(Ctx) {} + +protected: + Context &Ctx; +}; + +/// Cast wasi enum to c++ enum. +template constexpr WasiCryptoExpect cast(uint64_t) noexcept; + +template struct WasiRawType { + using Type = std::underlying_type_t; +}; +template <> struct WasiRawType { + using Type = uint8_t; +}; +template <> struct WasiRawType { + using Type = uint16_t; +}; +template <> struct WasiRawType { + using Type = uint32_t; +}; +template <> struct WasiRawType { + using Type = uint64_t; +}; + +template using WasiRawTypeT = typename WasiRawType::Type; + +template <> +constexpr WasiCryptoExpect<__wasi_algorithm_type_e_t> +cast(uint64_t AlgType) noexcept { + switch (static_cast>(AlgType)) { + case __WASI_ALGORITHM_TYPE_SIGNATURES: + case __WASI_ALGORITHM_TYPE_SYMMETRIC: + case __WASI_ALGORITHM_TYPE_KEY_EXCHANGE: + return static_cast<__wasi_algorithm_type_e_t>(AlgType); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); + } +} + +template <> +constexpr WasiCryptoExpect<__wasi_keypair_encoding_e_t> +cast(uint64_t Encoding) noexcept { + switch (static_cast>(Encoding)) { + case __WASI_KEYPAIR_ENCODING_RAW: + case __WASI_KEYPAIR_ENCODING_PKCS8: + case __WASI_KEYPAIR_ENCODING_PEM: + case __WASI_KEYPAIR_ENCODING_LOCAL: + return static_cast<__wasi_keypair_encoding_e_t>(Encoding); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template <> +constexpr WasiCryptoExpect<__wasi_publickey_encoding_e_t> +cast(uint64_t Encoding) noexcept { + switch (static_cast>(Encoding)) { + case __WASI_PUBLICKEY_ENCODING_RAW: + case __WASI_PUBLICKEY_ENCODING_PKCS8: + case __WASI_PUBLICKEY_ENCODING_PEM: + case __WASI_PUBLICKEY_ENCODING_SEC: + case __WASI_PUBLICKEY_ENCODING_LOCAL: + return static_cast<__wasi_publickey_encoding_e_t>(Encoding); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template <> +constexpr WasiCryptoExpect<__wasi_secretkey_encoding_e_t> +cast(uint64_t Encoding) noexcept { + switch (static_cast>(Encoding)) { + case __WASI_SECRETKEY_ENCODING_RAW: + case __WASI_SECRETKEY_ENCODING_PKCS8: + case __WASI_SECRETKEY_ENCODING_PEM: + case __WASI_SECRETKEY_ENCODING_SEC: + case __WASI_SECRETKEY_ENCODING_LOCAL: + return static_cast<__wasi_secretkey_encoding_e_t>(Encoding); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template <> +constexpr WasiCryptoExpect<__wasi_signature_encoding_e_t> +cast(uint64_t Encoding) noexcept { + switch (static_cast>(Encoding)) { + case __WASI_SIGNATURE_ENCODING_RAW: + case __WASI_SIGNATURE_ENCODING_DER: + return static_cast<__wasi_signature_encoding_e_t>(Encoding); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +/// Cast c++ size_t to wasi size_t. +constexpr WasiCryptoExpect<__wasi_size_t> toWasiSize(size_t Size) noexcept { + ensureOrReturn(Size <= std::numeric_limits<__wasi_size_t>::max(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return static_cast<__wasi_size_t>(Size); +} + +/// Convert string_view to inner Alg representation. +template WasiCryptoExpect tryFrom(std::string_view) noexcept; + +template <> +WasiCryptoExpect +tryFrom(std::string_view RawAlgStr) noexcept; + +WasiCryptoExpect +tryFrom(__wasi_algorithm_type_e_t AlgType, std::string_view RawAlgStr) noexcept; + +template <> +WasiCryptoExpect tryFrom(std::string_view RawAlgStr) noexcept; + +template <> +WasiCryptoExpect +tryFrom(std::string_view RawAlgStr) noexcept; + +/// Check exist or return `_algorithm_failure`. +#define checkExist(Expr) \ + do { \ + if (unlikely(!(Expr))) { \ + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; \ + } \ + } while (0) + +/// Check Span exist or return `_algorithm_failure`. +#define checkRangeExist(Expr, Size) \ + do { \ + if (unlikely((Expr).size() != (Size))) { \ + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; \ + } \ + } while (0) + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/optional.h b/plugins/wasi_crypto/utils/optional.h new file mode 100644 index 00000000..0752ed5e --- /dev/null +++ b/plugins/wasi_crypto/utils/optional.h @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/utils/handles_manager.h - OptionalRef ===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of OptionalRef and some helper +/// functions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" + +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +template using OptionalRef = T *; + +namespace detail { + +template struct IsOptional : std::false_type {}; +template struct IsOptional> : std::true_type { + using Type = T; +}; + +template struct IsExpected : std::false_type {}; +template struct IsExpected> : std::true_type { + using Type = T; +}; + +template struct IsOptionalRef : std::false_type {}; +template struct IsOptionalRef> : std::true_type { + using Type = T; +}; + +template struct IsExpectedOptionalRef : std::false_type {}; +template +struct IsExpectedOptionalRef>> + : std::true_type { + using Type = T; +}; + +} // namespace detail + +/// std::optional -> (T -> WasiCrypto) -> +/// WasiCryptoExpect> +template +inline auto mapAndTransposeOptional(const __wasi_opt_options_t Optional, + F &&Function) noexcept + -> std::enable_if_t< + detail::IsExpected(Function), Optional.u.some))>>::value, + WasiCryptoExpect(Function), Optional.u.some))>>::Type>>> { + if (Optional.tag == __WASI_OPT_OPTIONS_U_NONE) + return std::nullopt; + + return std::invoke(std::forward(Function), Optional.u.some); +} + +template +inline auto mapAndTransposeOptional(const __wasi_opt_symmetric_key_t Optional, + F &&Function) noexcept + -> std::enable_if_t< + detail::IsExpected(Function), Optional.u.some))>>::value, + WasiCryptoExpect(Function), Optional.u.some))>>::Type>>> { + if (Optional.tag == __WASI_OPT_SYMMETRIC_KEY_U_NONE) + return std::nullopt; + + return std::invoke(std::forward(Function), Optional.u.some); +} + +/// std::optional -> (T -> WasiCryptoExpect>) -> +/// WasiCryptoExpect> +template < + typename O, typename F, + typename = std::enable_if_t>::value>> +inline auto transposeOptionalToRef(O &&Optional, F &&Function) noexcept + -> WasiCryptoExpect(Function), *std::forward(Optional)))>>::Type>> { + if (!Optional) + return nullptr; + + return std::invoke(std::forward(Function), *Optional); +} + +/// OptionalRef -> (T -> WasiCryptoExpect>) -> +/// WasiCryptoExpect> +template < + typename O, typename F, + typename = std::enable_if_t>::value>> +inline auto transposeOptionalRef(O &&Optional, F &&Function) noexcept + -> WasiCryptoExpect(Function), *std::forward(Optional)))>>::Type>> { + if (!Optional) + return nullptr; + + return std::invoke(std::forward(Function), *Optional); +} + +/// std::optional -> OptionalRef +template >::value>> +inline auto asOptionalRef(O &&Optional) noexcept + -> OptionalRef>::Type> { + if (!Optional) + return nullptr; + + return &*Optional; +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/secret_vec.h b/plugins/wasi_crypto/utils/secret_vec.h new file mode 100644 index 00000000..9e2d2f1b --- /dev/null +++ b/plugins/wasi_crypto/utils/secret_vec.h @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/utils/secret_vec.h - Secret Vec def --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of the secret vec. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" + +#include "common/span.h" + +#include +#include + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +/// A vector wrapper, but swipe the secret key info on destroy. +class SecretVec { +public: + SecretVec(const SecretVec &) = default; + SecretVec &operator=(const SecretVec &) = default; + SecretVec &operator=(SecretVec &&) noexcept = default; + SecretVec(SecretVec &&) noexcept = default; + + SecretVec(Span Data) noexcept + : Data(Data.begin(), Data.end()) {} + + SecretVec(size_t Size) noexcept : Data(Size) {} + + ~SecretVec() noexcept { OPENSSL_cleanse(Data.data(), Data.size()); } + + auto begin() noexcept { return Data.begin(); } + auto begin() const noexcept { return Data.begin(); } + + auto end() noexcept { return Data.end(); } + auto end() const noexcept { return Data.end(); } + + auto size() const noexcept { return Data.size(); } + + auto data() noexcept { return Data.data(); } + auto data() const noexcept { return Data.data(); } + + using difference_type = std::vector::difference_type; + + /// Generate random size vector. Notice that the size shouldn't beyond + /// std::numeric_limits::max() because of the limitations of openssl. + template static WasiCryptoExpect random() noexcept { + static_assert( + Size <= std::numeric_limits::max(), + "Random key size shouldn't beyond std::numeric_limits::max()"); + + SecretVec Res(Size); + ensureOrReturn(RAND_bytes(Res.data(), static_cast(Size)), + __WASI_CRYPTO_ERRNO_RNG_ERROR); + return Res; + } + +private: + std::vector Data; +}; + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/CMakeLists.txt b/plugins/wasi_http/CMakeLists.txt new file mode 100644 index 00000000..45bc0030 --- /dev/null +++ b/plugins/wasi_http/CMakeLists.txt @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +include(FetchContent) +FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git + GIT_TAG 3b15fa82ea74739b574d705fea44959b58142eb8) +FetchContent_MakeAvailable(cpr) + +wasmedge_add_library(wasmedgePluginWasiHttp + SHARED + env.cpp + func.cpp + module.cpp +) + +target_compile_options(wasmedgePluginWasiHttp + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasiHttp + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/thirdparty +) + +target_link_libraries(wasmedgePluginWasiHttp + PUBLIC + cpr::cpr +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasiHttp + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiHttp + PRIVATE + wasmedge_shared + ) +endif() + +install( + TARGETS wasmedgePluginWasiHttp + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasi_http/README.md b/plugins/wasi_http/README.md new file mode 100644 index 00000000..0074d4b8 --- /dev/null +++ b/plugins/wasi_http/README.md @@ -0,0 +1,3 @@ +# wasi_http + +This is corresponding to [wasi-http preview2](https://github.com/WebAssembly/wasi-http), but a very beginning implementation, for now it's created to test component model. diff --git a/plugins/wasi_http/base.h b/plugins/wasi_http/base.h new file mode 100644 index 00000000..48243ec7 --- /dev/null +++ b/plugins/wasi_http/base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "env.h" + +#include "common/errcode.h" +#include "runtime/component/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template +class WasiHttp : public Runtime::Component::HostFunction { +public: + WasiHttp(WasiHttpEnvironment &HostEnv) + : Runtime::Component::HostFunction(), Env(HostEnv) {} + +protected: + WasiHttpEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/env.cpp b/plugins/wasi_http/env.cpp new file mode 100644 index 00000000..b0743d14 --- /dev/null +++ b/plugins/wasi_http/env.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "env.h" +#include "module.h" + +namespace WasmEdge { +namespace Host { + +WasiHttpEnvironment::WasiHttpEnvironment() noexcept {} + +namespace { + +Runtime::Instance::ComponentInstance * +create(const Plugin::PluginComponent::ComponentDescriptor *) noexcept { + return new WasiHttpModule(); +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasi_http", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 1, 0, 0}, + .ModuleCount = 0, + .ModuleDescriptions = {}, + .ComponentCount = 1, + .ComponentDescriptions = + (Plugin::PluginComponent::ComponentDescriptor[]){ + { + .Name = "wasi:http/test", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/env.h b/plugins/wasi_http/env.h new file mode 100644 index 00000000..2df43a2b --- /dev/null +++ b/plugins/wasi_http/env.h @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC +#pragma once + +#include "plugin/plugin.h" + +#include +#include + +namespace WasmEdge { +namespace Host { + +class WasiHttpEnvironment { +public: + WasiHttpEnvironment() noexcept; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/func.cpp b/plugins/wasi_http/func.cpp new file mode 100644 index 00000000..bb101044 --- /dev/null +++ b/plugins/wasi_http/func.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func.h" +#include "common/defines.h" +#include "common/errcode.h" + +#include +#include +#include +#include + +using namespace std::literals; + +namespace WasmEdge { +namespace Host { + +Expect WasiHttpPrint::body(std::string S) { + spdlog::info("[WASI-HTTP] print: {}"sv, S); + return {}; +} + +Expect WasiHttpGet::body(std::string URI) { + spdlog::info("[WASI-HTTP] URI: {}"sv, URI); + cpr::Response Res = cpr::Get( + cpr::Url{URI}, cpr::Authentication{"user", "pass", cpr::AuthMode::BASIC}); + spdlog::info("[WASI-HTTP] status: {}"sv, Res.status_code); + + return std::move(Res.text); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/func.h b/plugins/wasi_http/func.h new file mode 100644 index 00000000..855f3855 --- /dev/null +++ b/plugins/wasi_http/func.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { + +class WasiHttpPrint : public WasiHttp { +public: + WasiHttpPrint(WasiHttpEnvironment &HostEnv) : WasiHttp(HostEnv) {} + Expect body(std::string Str); +}; + +class WasiHttpGet : public WasiHttp { +public: + WasiHttpGet(WasiHttpEnvironment &HostEnv) : WasiHttp(HostEnv) {} + Expect body(std::string URI); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/module.cpp b/plugins/wasi_http/module.cpp new file mode 100644 index 00000000..49fc4fb6 --- /dev/null +++ b/plugins/wasi_http/module.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiHttpModule::WasiHttpModule() : ComponentInstance("wasi:http/test") { + addHostFunc("http-get", std::make_unique(Env)); + addHostFunc("print", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/module.h b/plugins/wasi_http/module.h new file mode 100644 index 00000000..93bb8776 --- /dev/null +++ b/plugins/wasi_http/module.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiHttpModule : public Runtime::Instance::ComponentInstance { +public: + WasiHttpModule(); + + WasiHttpEnvironment &getEnv() { return Env; } + +private: + WasiHttpEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt new file mode 100644 index 00000000..d7be3bcf --- /dev/null +++ b/plugins/wasi_nn/CMakeLists.txt @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_library(wasmedgePluginWasiNN + SHARED + wasinnenv.cpp + wasinnfunc.cpp + wasinnmodule.cpp + wasinn_openvino.cpp + wasinn_openvino_genai.cpp + wasinn_onnx.cpp + wasinn_tf.cpp + wasinn_torch.cpp + wasinn_tfl.cpp + GGML/core/ggml_core.cpp + wasinn_neuralspeed.cpp + wasinn_piper.cpp + wasinn_whisper.cpp + wasinn_chattts.cpp + wasinn_mlx.cpp + wasinn_bitnet.cpp +) + +include(WASINNDeps) +wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) + +set(WASMEDGE_WASI_NN_VERSION "0.1.34" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") + +# Handle the version of the WASI-NN plugin +string(REPLACE "." ";" WASI_NN_VERSION_LIST ${WASMEDGE_WASI_NN_VERSION}) +list(GET WASI_NN_VERSION_LIST 0 WASI_NN_VERSION_MAJOR) +list(GET WASI_NN_VERSION_LIST 1 WASI_NN_VERSION_MINOR) +list(GET WASI_NN_VERSION_LIST 2 WASI_NN_VERSION_PATCH) + +target_compile_definitions(wasmedgePluginWasiNN PRIVATE + WASI_NN_VERSION_MAJOR=${WASI_NN_VERSION_MAJOR} + WASI_NN_VERSION_MINOR=${WASI_NN_VERSION_MINOR} + WASI_NN_VERSION_PATCH=${WASI_NN_VERSION_PATCH} +) +# This foreach iteration handles the additional sources. +# The dependencies are moved to `cmake/WASINNDeps.cmake`. +foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) + string(TOLOWER ${BACKEND} BACKEND) + if(BACKEND STREQUAL "mlx") + target_sources(wasmedgePluginWasiNN + PRIVATE + MLX/prompt/prompt.cpp + MLX/model/llm/transformer.cpp + MLX/model/llm/registry.cpp + MLX/model/gemma3/language.cpp + MLX/model/gemma3/vision.cpp + MLX/model/gemma3/gemma3.cpp + MLX/model/converter.cpp + MLX/model/utils.cpp + MLX/model/vlm_base.cpp + MLX/model/vlm_sampling.cpp + MLX/model/whisper/whisper.cpp + MLX/model/whisper/tokenizer.cpp + MLX/model/whisper/decoding.cpp + MLX/model/whisper_transcribe.cpp + MLX/mlx/base.cpp + MLX/mlx/linear.cpp + MLX/mlx/convolution.cpp + MLX/mlx/positional_encoding.cpp + MLX/mlx/activations.cpp + MLX/mlx/embedding.cpp + MLX/mlx/normalization.cpp + MLX/mlx/transformer.cpp + MLX/mlx/pooling.cpp + MLX/mlx/quantized.cpp + ) + endif() + + if(BACKEND STREQUAL "ggml") + target_sources(wasmedgePluginWasiNN + PRIVATE + GGML/core/ggml_core.cpp + GGML/core/input_processor.cpp + GGML/core/output_generator.cpp + GGML/metadata/metadata_parser.cpp + GGML/compute/compute_engine.cpp + GGML/compute/inference_manager.cpp + GGML/tts/tts_core.cpp + GGML/utils.cpp + ) + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_HIP) + find_package(hip REQUIRED) + find_package(hipblas REQUIRED) + set(GGML_HIP ON CACHE BOOL "Build GGML with HIP" FORCE) + if(DEFINED ENV{GPU_TARGETS}) + set(GPU_TARGETS $ENV{GPU_TARGETS} CACHE STRING "HIP GPU targets") + else() + set(GPU_TARGETS "gfx90c" CACHE STRING "Default for Vega iGPU") + endif() + message(STATUS "Enabling HIP for ggml backend with targets: ${GPU_TARGETS}") + target_compile_definitions(wasmedgePluginWasiNN PRIVATE GGML_HIP=${GGML_HIP} GPU_TARGETS=${GPU_TARGETS}) + target_link_libraries(wasmedgePluginWasiNN PRIVATE hip::host roc::hipblas) + endif() + endif() + if(BACKEND STREQUAL "piper") + if(DEFINED PIPER_ROOT) + find_library(ESPEAK_NG_LIB + NAMES espeak-ng libespeak-ng + PATHS /usr/local/lib /usr/local/lib64 + NO_DEFAULT_PATH + ) + if (NOT ESPEAK_NG_LIB) + find_library(ESPEAK_NG_LIB NAMES espeak-ng libespeak-ng) + endif() + + find_library(UCD_LIB + NAMES ucd libucd + PATHS /usr/local/lib /usr/local/lib64 + NO_DEFAULT_PATH + ) + if (NOT UCD_LIB) + find_library(UCD_LIB NAMES ucd libucd) + endif() + + set(ESPEAK_TARGETS ${ESPEAK_NG_LIB} ${UCD_LIB}) + else() + set(ESPEAK_TARGETS "") + endif() + + target_link_libraries(wasmedgePluginWasiNN + PRIVATE + onnxruntime + ${ESPEAK_TARGETS} + ) + endif() +endforeach() + +target_compile_options(wasmedgePluginWasiNN + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasiNN + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_BUILD_WASI_NN_RPC) + add_definitions(-DWASMEDGE_BUILD_WASI_NN_RPC) + target_include_directories(wasmedgePluginWasiNN + SYSTEM BEFORE PUBLIC ${Protobuf_INCLUDE_DIR} + ) + target_link_libraries(wasmedgePluginWasiNN + PRIVATE + wasiNNRPC + ) +endif() + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasiNN + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiNN + PRIVATE + wasmedge_shared + ) +endif() + +install( + TARGETS wasmedgePluginWasiNN + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasi_nn/GGML/compute/compute_engine.cpp b/plugins/wasi_nn/GGML/compute/compute_engine.cpp new file mode 100644 index 00000000..5f569d0c --- /dev/null +++ b/plugins/wasi_nn/GGML/compute/compute_engine.cpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "GGML/core/ggml_core.h" +#include "GGML/tts/tts_core.h" +#include "inference_manager.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML + +Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "compute") + + // Clear the context and reset the sampler. + clearContext(GraphRef, CxtRef); + + if (GraphRef.Params.embedding) { + return getEmbedding(GraphRef, CxtRef); + } + + // Evaluate the input tokens. + ErrNo ReturnCode = ErrNo::Success; + if (GraphRef.VisionContext == nullptr) { + // Text only prompt. + ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + } else { + // Multimodal prompt. + llama_pos NewNPos; + int32_t Res = mtmd_helper_eval_chunks( + GraphRef.VisionContext.get(), GraphRef.LlamaContext.get(), + GraphRef.VisionInputChunks.get(), CxtRef.NPos, + /* seq_id */ 0, static_cast(CxtRef.CurrentBatchSize), + /* logits_last */ true, &NewNPos); + CxtRef.NPos = NewNPos; + if (Res != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "compute: unable to eval the mtmd prompt."sv) + } + } + + // Main prediction loop. + LOG_DEBUG(GraphRef.EnableDebugLog, "compute: enter main prediction loop"sv) + int64_t NPredict = + CxtRef.Conf.NPredict < 0 ? INT32_MAX : CxtRef.Conf.NPredict; + + while (NPredict-- > 0) { + ReturnCode = sampleOutput(GraphRef, CxtRef); + if (ReturnCode != ErrNo::Success) { + break; + } + } + if (ReturnCode == ErrNo::EndOfSequence) { + ReturnCode = ErrNo::Success; + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: enter main prediction loop...Done"sv) + // End of main prediction loop. + + // TTS: convert output codes to audio file. + if (GraphRef.TextToSpeech) { + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: convert output codes to audio file."sv) + ReturnCode = codesToSpeech(Env, GraphRef, CxtRef); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, + "compute: failed to convert output codes to audio "sv + "file."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: convert output codes to audio file...Done"sv) + } + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler); + } + + LOG_DEBUG(GraphRef.EnableDebugLog, "compute...Done"sv) + return ReturnCode; +} + +Expect computeSingle(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle"sv) + + // New compute single token context. + auto ReturnCode = ErrNo::Success; + if (!CxtRef.ComputeSingleStarted) { + CxtRef.ComputeSingleStarted = true; + + // Clear the context and reset the sampler. + clearContext(GraphRef, CxtRef); + + // Evaluate the input tokens. + if (GraphRef.VisionContext == nullptr) { + // Text only prompt. + ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + } else { + // Multimodal prompt. + llama_pos NewNPos; + int32_t Res = mtmd_helper_eval_chunks( + GraphRef.VisionContext.get(), GraphRef.LlamaContext.get(), + GraphRef.VisionInputChunks.get(), CxtRef.NPos, + /* seq_id */ 0, static_cast(CxtRef.CurrentBatchSize), + /* logits_last */ true, &NewNPos); + CxtRef.NPos = NewNPos; + if (Res != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "compute: unable to eval the mtmd prompt."sv) + } + } + } + + // Main prediction process. + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: enter main prediction process"sv) + ReturnCode = sampleOutput(GraphRef, CxtRef, true); + if (ReturnCode != ErrNo::Success) { + CxtRef.ComputeSingleStarted = false; + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: enter main prediction process...Done"sv) + // End of main predict process. + + LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle...Done"sv) + return ReturnCode; +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/compute/inference_manager.cpp b/plugins/wasi_nn/GGML/compute/inference_manager.cpp new file mode 100644 index 00000000..83dd3b06 --- /dev/null +++ b/plugins/wasi_nn/GGML/compute/inference_manager.cpp @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "inference_manager.h" +#include "GGML/core/ggml_core.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +// Fill a batch with tokens (smaller than batch size) and position data. +void fillBatch(Span Tokens, Graph &GraphRef, + llama_batch &Batch, int &NPos, bool IsLogit = false) { + assuming(GraphRef.Params.n_batch >= static_cast(Tokens.size())); + assuming(Batch.token != nullptr); + assuming(Batch.pos != nullptr); + assuming(Batch.logits != nullptr); + // Fill the batch with pos information. + Batch.n_tokens = static_cast(Tokens.size()); + for (uint32_t I = 0; I < Tokens.size(); I++) { + Batch.token[I] = Tokens[I]; + Batch.pos[I] = NPos + I; + Batch.logits[I] = false; + } + + // Logits for sampling or the end of inputs. + if (IsLogit) { + Batch.logits[Tokens.size() - 1] = true; + } + + // Move the position. + NPos += static_cast(Tokens.size()); +} + +// Evaluate tokens. Construct the batch from tokens and decode. +ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, + llama_batch &Batch, int &NPos, + bool IsLogits = false) noexcept { + // End the inference if the context is full. + uint32_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); + if (NPos + static_cast(Tokens.size()) > NCtx) { + LOG_INFO( + GraphRef.EnableLog, + "evaluateTokens: the context if full ({} / {} tokens). Please increase your "sv + "context size."sv, + NPos + static_cast(Tokens.size()), NCtx) + return ErrNo::ContextFull; + } + + // Loop for decoding batches. Split tokens by batch size. + for (int I = 0; I < static_cast(Tokens.size()); + I += static_cast(GraphRef.Params.n_batch)) { + int NEval = static_cast(Tokens.size()) - I; + if (NEval > static_cast(GraphRef.Params.n_batch)) { + NEval = static_cast(GraphRef.Params.n_batch); + } + + // Fill the batch with pos information. + fillBatch(Span(Tokens.begin() + I, NEval), GraphRef, + Batch, NPos, + IsLogits && I + NEval >= static_cast(Tokens.size())); + + // Decode the batch. + auto Status = llama_decode(GraphRef.LlamaContext.get(), Batch); + if (Status == 1) { + RET_ERROR( + ErrNo::RuntimeError, + "evaluateTokens: failed to llama_decode: try reducing the size of the batch "sv + "or increasing the size of context."sv) + } + if (Status < 0) { + RET_ERROR( + ErrNo::RuntimeError, + "evaluateTokens: failed to llama_decode: fatal error. Please open "sv + "an issue on GitHub."sv) + } + } + + return ErrNo::Success; +} +// Generate output embedding. +void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, + const float *Embeddings) noexcept { + // Embedding vector format + // | Content | + // | ----------------------------------- | + // | '{"number_embedding": ' | + // | n_embedding | + // | ', "embedding": ' | + // | '[' | + // | n_embedding*(embedding value %.10f) | + // | (n_embedding-1)*(',') | + // | ']' | + // | '}' | + Embedding = + fmt::format(R"({{"n_embedding": {}, )" + R"("embedding": [{:.10}]}})"sv, + NEmbd, fmt::join(Embeddings, Embeddings + NEmbd, ","sv)); +} +} // namespace + +// Evaluate the input tokens. Clear all inputs on success. +ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, + std::string_view LogPrefix) noexcept { + // Check if the input is set before setting up the context. + if (CxtRef.LlamaInputs.size() == 0) { + RET_ERROR(ErrNo::InvalidArgument, "{}: llama input is not set!"sv, + LogPrefix) + } + + // Get the context size. + const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); + // Minus 4 for the special tokens. (Such as , , ... tokens.) + const uint64_t MaxTokensListSize = NCtx - 4; + // Return value. + auto ReturnCode = ErrNo::Success; + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + RET_ERROR(ErrNo::PromptTooLong, + "{}: the prompt is too long. Your input has {} tokens. "sv + "Please reduce it to {} tokens."sv, + LogPrefix, CxtRef.LlamaInputs.size(), MaxTokensListSize) + } + + // Evaluate input tokens. + ReturnCode = + evaluateTokens(Span(CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.size()), + GraphRef, CxtRef.LlamaBatch, CxtRef.NPos, true); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, "{}: failed to evaluate input tokens."sv, LogPrefix) + } + + return ErrNo::Success; +} + +// Clear the context and reset the sampler. +void clearContext(Graph &GraphRef, Context &CxtRef) noexcept { + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext"sv) + llama_memory_clear(llama_get_memory(GraphRef.LlamaContext.get()), true); + common_sampler_reset(CxtRef.LlamaSampler); + CxtRef.NPos = 0; + CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext...Done"sv) +} + +// TODO: Merge into compute. +Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding"sv) + + const llama_vocab *Vocab = llama_model_get_vocab(GraphRef.LlamaModel.get()); + // Add SEP if not present. + if (CxtRef.LlamaInputs.size() > 0 && + CxtRef.LlamaInputs.back() != llama_vocab_sep(Vocab)) { + LOG_WARN( + "getEmbedding: last token in the prompt is not SEP, "sv + "'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF "sv + "header."sv) + } + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > + GraphRef.Params.n_batch) { + RET_ERROR( + ErrNo::PromptTooLong, + "getEmbedding: the prompt is too long. Your input has {} tokens exceeds batch "sv + "size {}. Please reduce the input size or increase your batch-size."sv, + CxtRef.LlamaInputs.size(), GraphRef.Params.n_batch) + } + + // Evaluate the input tokens. + auto ReturnCode = evaluateInput(GraphRef, CxtRef, "getEmbedding"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + + // Main prediction loop. + const int32_t NEmbd = llama_model_n_embd(GraphRef.LlamaModel.get()); + std::vector Embeddings(NEmbd); + + for (int I = 0; I < CxtRef.LlamaBatch.n_tokens; I++) { + if (!CxtRef.LlamaBatch.logits[I]) { + continue; + } + + // Try to get sequence embeddings. + auto *Embd = llama_get_embeddings_seq(GraphRef.LlamaContext.get(), + CxtRef.LlamaBatch.seq_id[I][0]); + if (Embd == nullptr) { + Embd = llama_get_embeddings_ith(GraphRef.LlamaContext.get(), I); + if (Embd == nullptr) { + LOG_ERROR("getEmbedding: failed to get embeddings for token {}"sv, I); + continue; + } + } + + // Normalize the embeddings. + common_embd_normalize(Embd, Embeddings.data(), NEmbd, + static_cast(CxtRef.Conf.EmbdNormalize)); + } + + std::string EmbeddingString; + buildOutputEmbedding(EmbeddingString, NEmbd, Embeddings.data()); + CxtRef.LlamaOutputs = + std::vector(EmbeddingString.begin(), EmbeddingString.end()); + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), /* Sampler */ nullptr); + } + + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding...Done"sv) + return ErrNo::Success; +} + +// Sample and get the output token. +ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, + bool IsSingleTokenMode) noexcept { + // Use idx = -1 to sample the next token. + const llama_token Id = common_sampler_sample( + CxtRef.LlamaSampler, GraphRef.LlamaContext.get(), /* idx */ -1); + common_sampler_accept(CxtRef.LlamaSampler, Id, /* accept_grammar */ true); + + // Save the output token. + CxtRef.LlamaOutputTokens.emplace_back(Id); + std::string OutputString = + common_token_to_piece(GraphRef.LlamaContext.get(), Id); + CxtRef.LlamaOutputs.insert(CxtRef.LlamaOutputs.end(), OutputString.begin(), + OutputString.end()); + // In single token mode, we do not handle StreamStdout and ReversePrompt. + if (!IsSingleTokenMode) { + // When setting StreamStdout, we print the output to stdout. + if (CxtRef.Conf.StreamStdout) { + fmt::print("{}"sv, + common_token_to_piece(GraphRef.LlamaContext.get(), Id)); + std::fflush(stdout); + } + // Break if reverse prompt is found. + if (!CxtRef.Conf.ReversePrompt.empty() && + std::string(CxtRef.LlamaOutputs.begin(), CxtRef.LlamaOutputs.end()) + .find(CxtRef.Conf.ReversePrompt) != std::string::npos) { + LOG_INFO(GraphRef.EnableLog, "sampleOutput: reverse prompt found."sv) + return ErrNo::EndOfSequence; + } + } + // Deal with end of text token. + const llama_vocab *Vocab = llama_model_get_vocab(GraphRef.LlamaModel.get()); + // Only stop on EOS if GraphRef.Params.sampling.ignore_eos is false. + if (!GraphRef.Params.sampling.ignore_eos && + llama_vocab_is_eog(Vocab, common_sampler_last(CxtRef.LlamaSampler))) { + LOG_INFO(GraphRef.EnableLog, "sampleOutput: EOS token found."sv) + return ErrNo::EndOfSequence; + } + // Evaluate the output token. + return evaluateTokens(Span(&Id, 1), GraphRef, + CxtRef.OutputBatch, CxtRef.NPos, true); +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/compute/inference_manager.h b/plugins/wasi_nn/GGML/compute/inference_manager.h new file mode 100644 index 00000000..b876249c --- /dev/null +++ b/plugins/wasi_nn/GGML/compute/inference_manager.h @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once +#include "GGML/core/ggml_core.h" + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +void clearContext(Graph &GraphRef, Context &CxtRef) noexcept; +Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept; +ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, + std::string_view LogPrefix) noexcept; +ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, + bool IsSingleTokenMode = false) noexcept; +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/core/ggml_core.cpp b/plugins/wasi_nn/GGML/core/ggml_core.cpp new file mode 100644 index 00000000..ce422a21 --- /dev/null +++ b/plugins/wasi_nn/GGML/core/ggml_core.cpp @@ -0,0 +1,382 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ggml_core.h" +#include "GGML/utils.h" +#include "common/types.h" +#include "host/wasi/vfs_io.h" +#include "wasinnenv.h" +#include + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include "GGML/metadata/metadata_parser.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +// Llama logging callback. +void llamaLogCallback(ggml_log_level LogLevel, const char *LogText, + void *UserData) { + Graph &GraphRef = *reinterpret_cast(UserData); + if (!GraphRef.EnableLog) { + return; + } + std::string Text(LogText); + // Remove the trailing newlines. + Text = Text.erase(Text.find_last_not_of("\n") + 1); + // Skip for "." + if (Text == ".") { + return; + } + if (LogLevel == GGML_LOG_LEVEL_ERROR) { + spdlog::error("[WASI-NN] llama.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_WARN) { + spdlog::warn("[WASI-NN] llama.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_INFO) { + spdlog::info("[WASI-NN] llama.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_DEBUG) { + spdlog::debug("[WASI-NN] llama.cpp: {}"sv, Text); + } +} +} // namespace + +Expect load(WasiNNEnvironment &Env, Span> Builders, + [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { + // Add a new graph. + EndianValue GId = Env.newGraph(Backend::GGML); + auto &GraphRef = Env.NNGraph[GId.raw()].get(); + + // Initialize the plugin parameters. + GraphRef.EnableLog = false; + GraphRef.EnableDebugLog = false; + common_params CommonParamsDefault; + CommonParamsDefault.lr.init(); + GraphRef.Params = CommonParamsDefault; + GraphRef.Params.n_keep = 0; + GraphRef.Params.n_chunks = -1; + GraphRef.Params.n_parallel = 1; + GraphRef.Params.grp_attn_n = 1; + GraphRef.Params.grp_attn_w = 512; + GraphRef.Params.n_print = -1; + GraphRef.Params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER; + // Initialize the model parameters. + llama_model_params ModelParamsDefault = llama_model_default_params(); + GraphRef.Params.n_gpu_layers = ModelParamsDefault.n_gpu_layers; + GraphRef.Params.mmproj.path = ""sv; + GraphRef.Params.warmup = false; + + // Initialize the sampling parameters. + const common_params_sampling SamplerParamsDefault; + GraphRef.Params.sampling = SamplerParamsDefault; + // Initialize the config parameters. + GraphRef.Conf.StreamStdout = false; + GraphRef.Conf.EmbdNormalize = + static_cast(CommonParamsDefault.embd_normalize); + GraphRef.Conf.NPredict = GraphRef.Params.n_predict; + GraphRef.Conf.ReversePrompt = ""sv; + GraphRef.Conf.ImagePath = ""sv; + + // Set llama log callback. + llama_log_set(llamaLogCallback, &GraphRef); + mtmd_helper_log_set(llamaLogCallback, &GraphRef); + + // If the graph builder length is greater than 1, builder[1] contains the + // metadata. + if (Builders.size() > 1) { + const std::string Metadata(reinterpret_cast(Builders[1].data()), + Builders[1].size()); + // Ignore context or model updates when initializing the graph. + auto Res = parseMetadata(GraphRef, GraphRef.Conf, Metadata); + if (Res != ErrNo::Success) { + Env.deleteGraph(GId.raw()); + RET_ERROR(Res, "load: Failed to parse metadata."sv) + } + } + + // Logging. + LOG_DEBUG(GraphRef.EnableDebugLog, "load"sv) + LOG_INFO(GraphRef.EnableLog, "LLAMA_COMMIT {}"sv, LLAMA_COMMIT) + LOG_INFO(GraphRef.EnableLog, "LLAMA_BUILD_NUMBER {}"sv, LLAMA_BUILD_NUMBER) + + // Handle the model path. + LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path."sv) + auto Weight = Builders[0]; + const std::string_view BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + if (BinModel.substr(0, 8) == "preload:"sv) { + GraphRef.Params.model.path = BinModel.substr(8); + } else { + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: Model path not found in nn-preload, write model into "sv + "a tmpfile."sv) + // TODO: pass the model directly to ggml. + // Write ggml model to file. + GraphRef.Params.model.path = "ggml-model.bin"sv; + WasmEdge::FStream::OFStream TempFile( + GraphRef.Params.model.path, std::ios_base::out | std::ios_base::binary, + Env.getEnv()); + if (!TempFile) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::InvalidArgument, + "load: Failed to create the temporary file. Currently, our "sv + "workaround involves creating a temporary model file named "sv + "\"ggml-model.bin\" and passing this filename as a "sv + "parameter to the ggml llama library."sv) + } + TempFile.write(BinModel.data(), BinModel.size()); + TempFile.close(); + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: Write model into a tmpfile...Done"sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path...Done"sv) + + // Check if the model exists. + if (!std::filesystem::exists( + std::filesystem::u8path(GraphRef.Params.model.path))) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::ModelNotFound, "load: model file not found."sv) + } + GraphRef.Params.model = GraphRef.Params.model; + + // Initialize ggml parameters. + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: initialize ggml model with given parameters."sv) + + common_params Params = GraphRef.Params; + Params.cpuparams.n_threads = + static_cast(GraphRef.Params.cpuparams.n_threads); + Params.cpuparams_batch.n_threads = + static_cast(GraphRef.Params.cpuparams.n_threads); + llama_backend_init(); + llama_numa_init(Params.numa); + + // Initialize the llama model and context. + llama_model_params ModelParams = common_model_params_to_llama(Params); + GraphRef.LlamaModel = llama_model_ptr( + llama_model_load_from_file(Params.model.path.c_str(), ModelParams)); + if (GraphRef.LlamaModel == nullptr) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init model."sv) + } + GraphRef.LlamaContext = llama_context_ptr(llama_init_from_model( + GraphRef.LlamaModel.get(), common_context_params_to_llama(Params))); + if (GraphRef.LlamaContext == nullptr) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init context."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: initialize ggml model with given parameters...Done"sv) + + // Initialize the TTS related model and context. + if (GraphRef.TextToSpeech) { + LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model."sv) + Params.model = GraphRef.Params.vocoder.model; + Params.embedding = true; + llama_model_params TTSModelParams = common_model_params_to_llama(Params); + GraphRef.TTSModel = llama_model_ptr( + llama_model_load_from_file(Params.model.path.c_str(), TTSModelParams)); + if (GraphRef.TTSModel == nullptr) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS model."sv) + } + GraphRef.TTSContext = llama_context_ptr(llama_init_from_model( + GraphRef.TTSModel.get(), common_context_params_to_llama(Params))); + if (GraphRef.TTSContext == nullptr) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS context."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model...Done"sv) + } + + // Store the loaded graph. + GraphId = GId.le(); + Env.NNGraph[GId.raw()].setReady(); + + LOG_DEBUG(GraphRef.EnableDebugLog, "load...Done"sv) + return ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx"sv) + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + LOG_INFO(GraphRef.EnableLog, "llama_system_info: {}"sv, + llama_print_system_info()) + + auto &CxtRef = Env.NNContext[ContextId].get(); + // Allocate the batch for input string prompt tokens. + CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); + CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; + + // Allocate the batch for output sampling. The batch size is always 1. + CxtRef.OutputBatch = allocBatch(1); + + // Allocate sampler. + CxtRef.LlamaSampler = + common_sampler_init(GraphRef.LlamaModel.get(), GraphRef.Params.sampling); + + Env.NNContext[ContextId].setReady(); + ContextId = EndianValue(ContextId).le(); + LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx...Done"sv) + return ErrNo::Success; +} + +Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle"sv) + + // Logging for the llama timings. + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler); + } + + // Clear the outputs. + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the previous output and tokens"sv) + CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the previous output and tokens...Done"sv) + + // Reset the llama sampler. + common_sampler_reset(CxtRef.LlamaSampler); + CxtRef.ComputeSingleStarted = false; + CxtRef.NPos = 0; + + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle...Done"sv) + return ErrNo::Success; +} + +Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + const bool IsDebugLog = GraphRef.EnableDebugLog; + LOG_DEBUG(IsDebugLog, "unload"sv) + + // TODO: Move the resource deallocation into the destructor. + if (GraphRef.LlamaModel != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free llama model"sv) + GraphRef.LlamaModel.reset(); + LOG_DEBUG(IsDebugLog, "unload: free llama model...Done"sv) + } + if (GraphRef.LlamaContext != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free llama context"sv) + GraphRef.LlamaContext.reset(); + LOG_DEBUG(IsDebugLog, "unload: free llama context...Done"sv) + } + if (GraphRef.VisionContext != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free mtmd context"sv) + GraphRef.VisionContext.reset(); + LOG_DEBUG(IsDebugLog, "unload: free mtmd context...Done"sv) + } + if (GraphRef.VisionInputChunks != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free mtmd chunks"sv) + GraphRef.VisionInputChunks.reset(); + LOG_DEBUG(IsDebugLog, "unload: free mtmd chunks...Done"sv) + } + if (GraphRef.TTSModel != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free TTS model"sv) + GraphRef.TTSModel.reset(); + LOG_DEBUG(IsDebugLog, "unload: free TTS model...Done"sv) + } + if (GraphRef.TTSContext != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free TTS context"sv) + GraphRef.TTSContext.reset(); + LOG_DEBUG(IsDebugLog, "unload: free TTS context...Done"sv) + } + if (!GraphRef.TensorBuftOverrides.empty()) { + LOG_DEBUG(IsDebugLog, "unload: free tensor buffer overrides"sv) + GraphRef.TensorBuftOverrides.clear(); + LOG_DEBUG(IsDebugLog, "unload: free tensor buffer overrides...Done"sv) + } + Env.deleteGraph(GraphId); + Env.mdRemoveById(GraphId); + + LOG_DEBUG(IsDebugLog, "unload...Done"sv) + return ErrNo::Success; +} + +Expect finalizeExecCtx(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context"sv) + + if (CxtRef.LlamaSampler != nullptr) { + LOG_DEBUG(GraphRef.EnableDebugLog, + "finalize_execution_context: free compute_single sampler"sv) + common_sampler_free(CxtRef.LlamaSampler); + CxtRef.LlamaSampler = nullptr; + LOG_DEBUG( + GraphRef.EnableDebugLog, + "finalize_execution_context: free compute_single sampler...Done"sv) + } + llama_batch_free(CxtRef.LlamaBatch); + llama_batch_free(CxtRef.OutputBatch); + Env.deleteContext(ContextId); + + LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context...Done"sv) + return ErrNo::Success; +} + +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] ggml backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"ggml\" to build it."sv); + return ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WasiNNEnvironment &, Span>, Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect getOutputSingle(WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect computeSingle(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect finiSingle(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect unload(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect finalizeExecCtx(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/core/ggml_core.h b/plugins/wasi_nn/GGML/core/ggml_core.h new file mode 100644 index 00000000..68671572 --- /dev/null +++ b/plugins/wasi_nn/GGML/core/ggml_core.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ggml_type.h" +#include "plugin/plugin.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} // namespace WasmEdge::Host::WASINN +namespace WasmEdge::Host::WASINN::GGML { + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect finiSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect getOutputSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect computeSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; +Expect finalizeExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/core/ggml_type.h b/plugins/wasi_nn/GGML/core/ggml_type.h new file mode 100644 index 00000000..cda10b0e --- /dev/null +++ b/plugins/wasi_nn/GGML/core/ggml_type.h @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "wasinntypes.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include "wasinntypes.h" +#include +#include +#include +#include +#include +#include + +#endif + +namespace WasmEdge::Host::WASINN::GGML { + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +enum class EmbdNormalizeType : int32_t { + // Follow: + // https://github.com/ggerganov/llama.cpp/blob/0bf2d10c5514ff61b99897a4a5054f846e384e1e/common/common.h#L312 + None = -1, + MaxAbsolute = 0, + Taxicab = 1, + Euclidean = 2, + PNorm = 3, +}; + +struct LocalConfig { + // Configuration values that can be changed in every context. + // The graph holds a default config parsed from metadata during loading. + // The context inherits a copy from the graph during creation and can be + // modified when parsing metadata in set_input. + bool StreamStdout = false; + EmbdNormalizeType EmbdNormalize = EmbdNormalizeType::Euclidean; + int64_t NPredict; + std::string ReversePrompt; + std::string ImagePath; + bool AlwaysRegenerateImageEmbd = false; +}; + +struct Graph { + // Plugin parameters: + bool EnableLog = false; + bool EnableDebugLog = false; + common_params Params; + std::list TensorBuftOverrides; + // Model context: + llama_model_ptr LlamaModel = nullptr; + llama_context_ptr LlamaContext = nullptr; + // Multimodal context: + mtmd::context_ptr VisionContext = nullptr; + mtmd::input_chunks_ptr VisionInputChunks = nullptr; + // Text-to-speech: + bool TextToSpeech = false; + std::string TTSOutputFilePath = "output.wav"; + std::string TTSSpeakerFilePath; + llama_model_ptr TTSModel = nullptr; + llama_context_ptr TTSContext = nullptr; + // Configs. + LocalConfig Conf; +}; + +struct Context { +public: + Context(uint32_t GId, Graph &G) noexcept : GraphId(GId), Conf(G.Conf) {} + uint32_t GraphId; + // Llama inputs: + std::vector LlamaInputs; + uint64_t LlamaNInputs = 0; + // Llama outputs: + std::vector LlamaOutputs; + std::vector LlamaOutputTokens; + // Data for computation: + bool ComputeSingleStarted = false; + struct common_sampler *LlamaSampler = nullptr; + // Handle the batch in the context to prevent reallocation during every + // computation. + struct llama_batch LlamaBatch; + struct llama_batch OutputBatch; + int64_t CurrentBatchSize = 0; + size_t ImagePosition = 0; + int32_t NPos = 0; + // Configs: + LocalConfig Conf; +}; +#else +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +// Macro for logging debug message. +#define LOG_DEBUG(Debug, ...) \ + if (Debug) { \ + spdlog::info("[WASI-NN][Debug] GGML backend: "sv __VA_ARGS__); \ + } + +// Macro for logging info message. +#define LOG_INFO(Info, ...) \ + if (Info) { \ + spdlog::info("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ + } + +// Macro for logging warning message. +#define LOG_WARN(...) spdlog::warn("[WASI-NN] GGML backend: "sv __VA_ARGS__); + +// Macro for logging error message. +#define LOG_ERROR(...) spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); + +// Macro for logging an error message and returning. +#define RET_ERROR(Error, ...) \ + spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ + return Error; +} // namespace +#endif + +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/core/input_processor.cpp b/plugins/wasi_nn/GGML/core/input_processor.cpp new file mode 100644 index 00000000..12f57de1 --- /dev/null +++ b/plugins/wasi_nn/GGML/core/input_processor.cpp @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "GGML/metadata/metadata_parser.h" +#include "GGML/tts/tts_core.h" +#include "GGML/utils.h" +#include "ggml_core.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput"sv) + + // Use index 1 for metadata. + if (Index == 1) { + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: found Metadata, processing"sv) + bool IsModelParamsUpdated = false; + bool IsContextParamsUpdated = false; + bool IsSamplerParamsUpdated = false; + const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + auto Res = + parseMetadata(GraphRef, CxtRef.Conf, Metadata, &IsModelParamsUpdated, + &IsContextParamsUpdated, &IsSamplerParamsUpdated); + if (Res != ErrNo::Success) { + RET_ERROR(Res, "setInput: failed to parse metadata."sv) + } + +#ifndef __APPLE__ + // XXX: Because of the limitation in the WASI-NN proposal, this is a + // workaround for non-macOS devices. However, if the model params are + // updated in the configuration stage, we do not recommend using this to + // avoid reloading the model. + { + if (IsModelParamsUpdated || GraphRef.LlamaModel == nullptr) { + // The llama model may be nullptr if set_input updated the model params + // last time. Therefore, in addition to updated model params, we should + // reload the llama model if the model is nullptr. + LOG_INFO(GraphRef.EnableLog, + "setInput: Reload model due to parameters change."sv) + llama_model_params ModelParams = llama_model_default_params(); + ModelParams.n_gpu_layers = + static_cast(GraphRef.Params.n_gpu_layers); + GraphRef.LlamaModel.reset(); + // Due to the model change, the context and sampler should also be + // reloaded. The new context and sampler will be created in the next + // block. + GraphRef.LlamaContext.reset(); + if (CxtRef.LlamaSampler) { + // TODO: Trigger the sampler in other contexts to reallocate. + common_sampler_free(CxtRef.LlamaSampler); + CxtRef.LlamaSampler = nullptr; + } + GraphRef.LlamaModel = llama_model_ptr(llama_model_load_from_file( + GraphRef.Params.model.path.c_str(), ModelParams)); + if (GraphRef.LlamaModel == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init model."sv) + } + } + } +#endif + + // Some changes to context parameters will require the context to be + // reloaded. + if (IsContextParamsUpdated || GraphRef.LlamaContext == nullptr) { + LOG_INFO(GraphRef.EnableLog, + "setInput: Reload llama context due to parameters change."sv) + GraphRef.LlamaContext.reset(); + GraphRef.LlamaContext = llama_context_ptr(llama_init_from_model( + GraphRef.LlamaModel.get(), + common_context_params_to_llama(GraphRef.Params))); + if (GraphRef.LlamaContext == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init context."sv) + } + } + + // Some changes to sampling parameters will require the sampler to be + // reallocated. + if (IsSamplerParamsUpdated || CxtRef.LlamaSampler == nullptr) { + LOG_INFO(GraphRef.EnableLog, + "setInput: Reallocate llama sampler due to parameters change."sv) + if (CxtRef.LlamaSampler) { + common_sampler_free(CxtRef.LlamaSampler); + } + CxtRef.LlamaSampler = common_sampler_init(GraphRef.LlamaModel.get(), + GraphRef.Params.sampling); + if (GraphRef.LlamaContext == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init sampler."sv) + } + } + + // Check whether the batch size changed. + if (CxtRef.CurrentBatchSize != GraphRef.Params.n_batch) { + llama_batch_free(CxtRef.LlamaBatch); + CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); + CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; + } + + Env.NNGraph[CxtRef.GraphId].setReady(); + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: found Metadata, processing...Done"sv) + return ErrNo::Success; + } + + // Check that the graph is valid after reloading during the previous + // set_input. + if (!Env.NNGraph[CxtRef.GraphId].isReady()) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: Graph is invalid. Please reload again by passing metadata "sv + "in set_input or unload graph."sv) + } + + // Clear the llama context. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context"sv) + llama_memory_clear(llama_get_memory(GraphRef.LlamaContext.get()), true); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context...Done"sv) + + // Set the input. + const bool AddSpecial = true; + const bool ParseSpecial = true; + std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + CxtRef.LlamaInputs.clear(); + + auto Base64ImagePos = findBase64ImagePayload(Prompt); + + if (Base64ImagePos.has_value() || CxtRef.Conf.ImagePath != ""sv) { + // First check whether the projection model is provided. + if (GraphRef.Params.mmproj.path == ""sv) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: the given model does not support image input, so a projection model is required."sv) + } + + // Make sure the projection model is loaded. + if (GraphRef.VisionContext == nullptr) { + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: initialize mtmd context."sv) + // Initialize the mtmd context. + mtmd_context_params VisionContextParams = mtmd_context_params_default(); + std::string VisionPromptImagePlaceholderStr(VisionPromptImagePlaceholder); + VisionContextParams.media_marker = + VisionPromptImagePlaceholderStr.c_str(); + VisionContextParams.use_gpu = GraphRef.Params.mmproj_use_gpu; + VisionContextParams.n_threads = GraphRef.Params.cpuparams.n_threads; + VisionContextParams.print_timings = + GraphRef.EnableLog || GraphRef.EnableDebugLog; + GraphRef.VisionContext.reset( + mtmd_init_from_file(GraphRef.Params.mmproj.path.c_str(), + GraphRef.LlamaModel.get(), VisionContextParams)); + if (GraphRef.VisionContext == nullptr) { + RET_ERROR(ErrNo::InvalidArgument, + "setInput: unable to load the mmproj model {}."sv, + GraphRef.Params.mmproj.path) + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: initialize mtmd context...Done"sv) + } + + // Show some warnings for context size. + if (GraphRef.Params.n_ctx < 4096) { + LOG_INFO( + GraphRef.EnableLog, + "setInput: Context size is {}, we recommend context size >= 4096 when using multimodal models for better results"sv, + GraphRef.Params.n_ctx) + } + + // Get the image bitmaps. + // Follow this link for the supported image formats: + // https://github.com/ggml-org/llama.cpp/blob/master/common/stb_image.h + mtmd::bitmaps Bitmaps; + if (GraphRef.VisionContext != nullptr) { + if (Base64ImagePos.has_value()) { + // Load the image bitmap from the base64 image. + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: load the image bitmap from the base64 image."sv) + // Extract the payload and image type from the prompt. + std::optional, std::string>> Payload = + extractBase64ImagePayload(Prompt, *Base64ImagePos, + VisionPromptImagePlaceholder); + if (Payload.has_value()) { + // Create the new image bitmap. + mtmd::bitmap Bitmap(mtmd_helper_bitmap_init_from_buf( + GraphRef.VisionContext.get(), Payload->first.data(), + Payload->first.size())); + if (Bitmap.ptr == nullptr) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: unable to load the image from base64 paylaod."sv) + } + Bitmaps.entries.push_back(std::move(Bitmap)); + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: Compute image embd from the base64 image...Done"sv) + } else { + // Load the image from the file. + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: load the image bitmap from file: {}"sv, + CxtRef.Conf.ImagePath) + mtmd::bitmap Bitmap(mtmd_helper_bitmap_init_from_file( + GraphRef.VisionContext.get(), CxtRef.Conf.ImagePath.c_str())); + if (Bitmap.ptr == nullptr) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: unable to load the image bitmap from file: {}."sv, + CxtRef.Conf.ImagePath) + } + Bitmaps.entries.push_back(std::move(Bitmap)); + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: load the image bitmap from file: {}...Done"sv, + CxtRef.Conf.ImagePath) + } + } + + // Tokenize the prompt. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize the mtmd prompt"sv) + GraphRef.VisionInputChunks.reset(mtmd_input_chunks_init()); + mtmd_input_text MtmdText; + MtmdText.text = Prompt.c_str(); + MtmdText.add_special = AddSpecial; + MtmdText.parse_special = ParseSpecial; + std::vector BitmapsPtr = Bitmaps.c_ptr(); + int32_t Res = mtmd_tokenize(GraphRef.VisionContext.get(), + GraphRef.VisionInputChunks.get(), &MtmdText, + BitmapsPtr.data(), BitmapsPtr.size()); + if (Res != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "setInput: unable to tokenize the mtmd prompt."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: tokenize the mtmd prompt...Done"sv) + + // Get the number of input tokens for the metadata. + CxtRef.LlamaNInputs = 0; + for (size_t ChunkIndex = 0; + ChunkIndex < mtmd_input_chunks_size(GraphRef.VisionInputChunks.get()); + ++ChunkIndex) { + size_t NTokens = 0; + const mtmd_input_chunk *Chunk = + mtmd_input_chunks_get(GraphRef.VisionInputChunks.get(), ChunkIndex); + mtmd_input_chunk_get_tokens_text(Chunk, &NTokens); + CxtRef.LlamaNInputs += NTokens; + } + } else if (GraphRef.TextToSpeech == true) { + // TTS prompt. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt"sv) + CxtRef.LlamaInputs = processTTSPrompt(Env, GraphRef, Prompt); + if (CxtRef.LlamaInputs.empty()) { + RET_ERROR(ErrNo::InvalidArgument, + "setInput: failed to tokenize tts prompt."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt...Done"sv) + + // Get the number of input tokens for the metadata. + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); + } else { + // Text only prompt. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt"sv) + CxtRef.LlamaInputs = common_tokenize(GraphRef.LlamaContext.get(), Prompt, + AddSpecial, ParseSpecial); + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: tokenize text prompt...Done"sv) + + // Get the number of input tokens for the metadata. + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); + } + + // The context may currently be in compute_single mode. Reset the compute + // state. + CxtRef.ComputeSingleStarted = false; + + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput...Done"sv) + return ErrNo::Success; +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/core/output_generator.cpp b/plugins/wasi_nn/GGML/core/output_generator.cpp new file mode 100644 index 00000000..f5c5dda2 --- /dev/null +++ b/plugins/wasi_nn/GGML/core/output_generator.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ggml_core.h" + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { +// Generate output metadata. +std::string buildOutputMetadata(Context &CxtRef) noexcept { + return fmt::format(R"({{"input_tokens": {}, )" + R"("output_tokens": {}, )" + R"("llama_build_number": {}, )" + R"("llama_commit": "{}"}})"sv, + CxtRef.LlamaNInputs, CxtRef.LlamaOutputTokens.size(), + LLAMA_BUILD_NUMBER, LLAMA_COMMIT); +} +} // namespace + +Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}"sv, Index) + + // Use index 1 for output metadata. + if (Index == 1) { + std::string Metadata = buildOutputMetadata(CxtRef); + std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); + BytesWritten = static_cast(Metadata.length()); + LOG_DEBUG(GraphRef.EnableDebugLog, + "getOutputSingle: with Index {} a.k.a Metadata...Done"sv, Index) + return ErrNo::Success; + } + + std::string LastToken = common_token_to_piece( + GraphRef.LlamaContext.get(), CxtRef.LlamaOutputTokens.back()); + std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); + BytesWritten = EndianValue(static_cast(LastToken.length())).le(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}...Done"sv, + Index) + return ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}"sv, Index) + + // Use index 1 for output metadata. + if (Index == 1) { + std::string Metadata = buildOutputMetadata(CxtRef); + std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); + BytesWritten = static_cast(Metadata.length()); + LOG_DEBUG(GraphRef.EnableDebugLog, + "getOutput: with Index {} a.k.a Metadata ...Done"sv, Index) + return ErrNo::Success; + } + + std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.size(), + OutBuffer.data()); + BytesWritten = + EndianValue(static_cast(CxtRef.LlamaOutputs.size())).le(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}...Done"sv, Index) + return ErrNo::Success; +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp new file mode 100644 index 00000000..fc98904b --- /dev/null +++ b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp @@ -0,0 +1,565 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "metadata_parser.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +// Parse metadata from JSON. +ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, + const std::string &Metadata, bool *IsModelUpdated, + bool *IsContextUpdated, bool *IsSamplerUpdated) noexcept { + // Parse metadata from the json. + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + RET_ERROR(ErrNo::InvalidEncoding, "parse metadata error."sv) + } + + // Get the current llama parameters. + int64_t PrevNGPULayers = GraphRef.Params.n_gpu_layers; + bool PrevEmbedding = GraphRef.Params.embedding; + // Get the current sampler parameters. + double PrevTemp = GraphRef.Params.sampling.temp; + double PrevTopP = GraphRef.Params.sampling.top_p; + double PrevRepeatPenalty = GraphRef.Params.sampling.penalty_repeat; + double PrevPresencePenalty = GraphRef.Params.sampling.penalty_present; + double PrevFrequencyPenalty = GraphRef.Params.sampling.penalty_freq; + std::string PrevGrammar = + common_grammar_value(GraphRef.Params.sampling.grammar); + uint32_t PrevSeed = GraphRef.Params.sampling.seed; + + try { + parseJsonAuto(Doc, "enable-log", GraphRef.EnableLog); + parseJsonAuto(Doc, "enable-debug-log", GraphRef.EnableDebugLog); + + parseJsonWithCastAuto(Doc, "main-gpu", GraphRef.Params.main_gpu); + parseJsonWithCastAuto(Doc, "n-gpu-layers", + GraphRef.Params.n_gpu_layers); + + parseJsonWithProcessorAuto(Doc, "cpu-moe", + [&GraphRef](const bool &CpuMoe) -> bool { + if (CpuMoe) { + GraphRef.TensorBuftOverrides.push_back( + "\\.ffn_(up|down|gate)_exps"); + } + return true; + }); + + parseJsonWithProcessorAuto( + Doc, "n-cpu-moe", [&GraphRef](const int64_t &NCpuMoe) -> bool { + if (NCpuMoe < 0) { + spdlog::error("[WASI-NN] GGML backend: Invalid n-cpu-moe value."); + return false; + } + for (int I = 0; I < NCpuMoe; I++) { + GraphRef.TensorBuftOverrides.push_back( + string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", I)); + } + return true; + }); + parseJsonWithProcessorAuto( + Doc, "tensor-split", [&GraphRef](const std::string_view &TSV) -> bool { + // The TensorSplit is a comma-separated list of non-negative values. + // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. + std::string TS(TSV); + std::replace(TS.begin(), TS.end(), ',', ' '); + std::stringstream SS(TS); + std::memset(GraphRef.Params.tensor_split, 0, + sizeof(GraphRef.Params.tensor_split)); + uint32_t TensorSplitSize = 0; + while (SS.good()) { + float TmpTensor; + SS >> TmpTensor; + GraphRef.Params.tensor_split[TensorSplitSize++] = TmpTensor; + } + size_t NDevices = llama_max_devices(); + if (TensorSplitSize > NDevices) { + spdlog::error( + "[WASI-NN] GGML backend: Number of Tensor-Split is larger than " + "MaxDevices, please reduce the size of tensor-split."); + return false; + } + for (size_t Idx = TensorSplitSize; Idx < NDevices; Idx++) { + GraphRef.Params.tensor_split[TensorSplitSize++] = 0.0f; + } + return true; + }); + parseJsonAuto(Doc, "embedding", GraphRef.Params.embedding); + parseJsonWithProcessorAuto( + Doc, "split-mode", + [&GraphRef](const std::string_view &SplitMode) -> bool { + if (SplitMode == "none"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_NONE; + } else if (SplitMode == "layer"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_LAYER; + } else if (SplitMode == "row"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_ROW; + } else { + spdlog::error("[WASI-NN] GGML backend: Unknown split-mode: " + "{}. Valid: none, layer, row.", + SplitMode); + return false; + } + return true; + }); + parseJsonWithCastAuto(Doc, "mmproj", + GraphRef.Params.mmproj.path); + // The TTS parameters using macros + parseJsonAuto(Doc, "tts", GraphRef.TextToSpeech); + parseJsonWithCastAuto(Doc, "model-vocoder", + GraphRef.Params.vocoder.model.path); + parseJsonWithCastAuto(Doc, "tts-output-file", + GraphRef.TTSOutputFilePath); + + parseJsonWithCastAuto(Doc, "tts-speaker-file", + GraphRef.TTSSpeakerFilePath); + + // The context parameters + parseJsonWithCastAuto(Doc, "ctx-size", GraphRef.Params.n_ctx); + parseJsonWithCastAuto(Doc, "batch-size", GraphRef.Params.n_batch); + parseJsonWithCastAuto(Doc, "ubatch-size", + GraphRef.Params.n_ubatch); + parseJsonWithCastAuto(Doc, "n-keep", GraphRef.Params.n_keep); + parseJsonWithCastAuto(Doc, "n-chunks", GraphRef.Params.n_chunks); + parseJsonWithCastAuto(Doc, "n-parallel", + GraphRef.Params.n_parallel); + parseJsonWithCastAuto(Doc, "n-sequences", + GraphRef.Params.n_sequences); + parseJsonWithCastAuto(Doc, "grp-attn-n", + GraphRef.Params.grp_attn_n); + parseJsonWithCastAuto(Doc, "grp-attn-w", + GraphRef.Params.grp_attn_w); + parseJsonWithCastAuto(Doc, "n-print", GraphRef.Params.n_print); + parseJsonWithCastAuto(Doc, "rope-freq-base", + GraphRef.Params.rope_freq_base); + parseJsonWithCastAuto(Doc, "rope-freq-scale", + GraphRef.Params.rope_freq_scale); + parseJsonWithCastAuto(Doc, "yarn-ext-factor", + GraphRef.Params.yarn_ext_factor); + parseJsonWithCastAuto(Doc, "yarn-attn-factor", + GraphRef.Params.yarn_attn_factor); + parseJsonWithCastAuto(Doc, "yarn-beta-fast", + GraphRef.Params.yarn_beta_fast); + parseJsonWithCastAuto(Doc, "yarn-beta-slow", + GraphRef.Params.yarn_beta_slow); + parseJsonWithCastAuto(Doc, "yarn-orig-ctx", + GraphRef.Params.yarn_orig_ctx); + parseJsonAuto(Doc, "mask-valid", + GraphRef.Params.cpuparams.mask_valid); + parseJsonWithCastAuto(Doc, "priority", + GraphRef.Params.cpuparams.priority); + parseJsonAuto(Doc, "strict-cpu", + GraphRef.Params.cpuparams.strict_cpu); + parseJsonWithCastAuto(Doc, "poll", GraphRef.Params.cpuparams.poll); + parseJsonAuto(Doc, "mask-valid-batch", + GraphRef.Params.cpuparams_batch.mask_valid); + parseJsonWithProcessorAuto( + Doc, "priority-batch", [&GraphRef](const int64_t &Priority) -> bool { + GraphRef.Params.cpuparams_batch.priority = + static_cast(Priority); + return true; + }); + parseJsonAuto(Doc, "strict-cpu-batch", + GraphRef.Params.cpuparams_batch.strict_cpu); + parseJsonWithCastAuto(Doc, "poll-batch", + GraphRef.Params.cpuparams_batch.poll); + + parseJsonWithCastAuto(Doc, "numa", GraphRef.Params.numa); + + parseJsonWithCastAuto(Doc, "rope-scaling-type", + GraphRef.Params.rope_scaling_type); + parseJsonWithCastAuto(Doc, "pooling-type", + GraphRef.Params.pooling_type); + parseJsonWithCastAuto(Doc, "attention-type", + GraphRef.Params.attention_type); + parseJsonWithProcessorAuto( + Doc, "threads", [&GraphRef](const int64_t &NThreads) -> bool { + GraphRef.Params.cpuparams.n_threads = static_cast(NThreads); + return true; + }); + parseJsonWithCastAuto(Doc, "threads-batch", + GraphRef.Params.cpuparams_batch.n_threads); + parseJsonWithCastAuto(Doc, "n-prev", + GraphRef.Params.sampling.n_prev); + parseJsonWithCastAuto(Doc, "n-probs", + GraphRef.Params.sampling.n_probs); + parseJsonWithCastAuto(Doc, "min-keep", + GraphRef.Params.sampling.min_keep); + parseJsonWithCastAuto(Doc, "top-k", + GraphRef.Params.sampling.top_k); + parseJsonWithCastAuto(Doc, "min-p", GraphRef.Params.sampling.min_p); + parseJsonWithCastAuto(Doc, "xtc-probability", + GraphRef.Params.sampling.xtc_probability); + parseJsonWithCastAuto(Doc, "xtc-threshold", + GraphRef.Params.sampling.xtc_threshold); + parseJsonWithCastAuto(Doc, "typ-p", GraphRef.Params.sampling.typ_p); + parseJsonWithCastAuto(Doc, "dynatemp-range", + GraphRef.Params.sampling.dynatemp_range); + parseJsonWithCastAuto(Doc, "dynatemp-exponent", + GraphRef.Params.sampling.dynatemp_exponent); + parseJsonWithCastAuto(Doc, "last-n-penalty", + GraphRef.Params.sampling.penalty_last_n); + parseJsonWithProcessorAuto( + Doc, "temp", [&GraphRef](const double &Temp) -> bool { + GraphRef.Params.sampling.temp = + static_cast(std::max(0.0, Temp)); + return true; + }); + + parseJsonWithProcessorAuto( + Doc, "top-p", [&GraphRef](const double &TopP) -> bool { + GraphRef.Params.sampling.top_p = + static_cast(std::max(0.0, TopP)); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "repeat-penalty", + [&GraphRef](const double &RepeatPenalty) -> bool { + GraphRef.Params.sampling.penalty_repeat = + static_cast(std::max(0.0, RepeatPenalty)); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "presence-penalty", + [&GraphRef](const double &PresencePenalty) -> bool { + GraphRef.Params.sampling.penalty_present = + static_cast(std::max(0.0, PresencePenalty)); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "frequency-penalty", + [&GraphRef](const double &FrequencyPenalty) -> bool { + GraphRef.Params.sampling.penalty_freq = + static_cast(std::max(0.0, FrequencyPenalty)); + return true; + }); + parseJsonWithCastAuto(Doc, "dry-multipier", + GraphRef.Params.sampling.dry_multiplier); + parseJsonWithCastAuto(Doc, "dry-base", + GraphRef.Params.sampling.dry_base); + parseJsonWithCastAuto(Doc, "dry-allowed-length", + GraphRef.Params.sampling.dry_allowed_length); + parseJsonWithCastAuto(Doc, "dry-last-n-penalty", + GraphRef.Params.sampling.penalty_last_n); + parseJsonWithCastAuto(Doc, "mirostat", + GraphRef.Params.sampling.mirostat); + parseJsonWithCastAuto(Doc, "mirostat-eta", + GraphRef.Params.sampling.mirostat_eta); + parseJsonAuto(Doc, "ignore-eos", GraphRef.Params.sampling.ignore_eos); + parseJsonAuto(Doc, "no-perf-sampling", + GraphRef.Params.sampling.no_perf); + parseJsonAuto(Doc, "timing-per-token", + GraphRef.Params.sampling.timing_per_token); + + parseJsonWithProcessorAuto( + Doc, "grammar", [&GraphRef](const std::string_view &Grammar) -> bool { + GraphRef.Params.sampling.grammar = + common_grammar(COMMON_GRAMMAR_TYPE_USER, std::string(Grammar)); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "json-schema", + [&GraphRef](const std::string_view &JsonSchema) -> bool { + GraphRef.Params.sampling.grammar = + common_grammar(COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, + json_schema_to_grammar( + nlohmann::ordered_json::parse(JsonSchema))); + return true; + }); + parseJsonWithCastAuto(Doc, "seed", GraphRef.Params.sampling.seed); + + // The speculative parameters. + parseJsonWithCastAuto(Doc, "n-ctx-speculative", + GraphRef.Params.speculative.n_ctx); + parseJsonWithCastAuto(Doc, "n-max-speculative", + GraphRef.Params.speculative.n_max); + parseJsonWithCastAuto(Doc, "n-min-speculative", + GraphRef.Params.speculative.n_min); + parseJsonWithCastAuto(Doc, "n-gpu-layers-speculative", + GraphRef.Params.speculative.n_gpu_layers); + parseJsonWithCastAuto(Doc, "p-split-speculative", + GraphRef.Params.speculative.p_split); + parseJsonWithCastAuto(Doc, "p-min-speculative", + GraphRef.Params.speculative.p_min); + // The vocoder parameters. + parseJsonWithCastAuto( + Doc, "hf-repo-vocoder", GraphRef.Params.vocoder.model.hf_repo); + parseJsonWithCastAuto( + Doc, "hf-file-vocoder", GraphRef.Params.vocoder.model.hf_file); + parseJsonWithCastAuto(Doc, "model-url-vocoder", + GraphRef.Params.vocoder.model.url); + + // The config parameters. + parseJsonAuto(Doc, "stream-stdout", ConfRef.StreamStdout); + parseJsonAuto(Doc, "n-predict", ConfRef.NPredict); + parseJsonWithCastAuto(Doc, "reverse-prompt", + ConfRef.ReversePrompt); + parseJsonWithCastAuto(Doc, "image", ConfRef.ImagePath); + parseJsonAuto(Doc, "always-regenerate-image-embd", + ConfRef.AlwaysRegenerateImageEmbd); + parseJsonWithProcessorAuto( + Doc, "model-alias", + [&GraphRef](const std::string_view &ModelAlias) -> bool { + GraphRef.Params.model_alias.emplace(ModelAlias); + return true; + }); + parseJsonWithCastAuto(Doc, "model-url", + GraphRef.Params.model.url); + parseJsonWithCastAuto(Doc, "hf-token", + GraphRef.Params.hf_token); + parseJsonWithCastAuto(Doc, "hf-repo", + GraphRef.Params.model.hf_repo); + parseJsonWithCastAuto(Doc, "hf-file", + GraphRef.Params.model.hf_file); + parseJsonWithCastAuto(Doc, "prompt-file", + GraphRef.Params.prompt_file); + parseJsonWithCastAuto(Doc, "path-prompt-cache", + GraphRef.Params.path_prompt_cache); + parseJsonWithCastAuto(Doc, "input-prefix", + GraphRef.Params.input_prefix); + parseJsonWithCastAuto(Doc, "input-suffix", + GraphRef.Params.input_suffix); + parseJsonWithCastAuto( + Doc, "lookup-cache-static", + GraphRef.Params.speculative.lookup_cache_static); + parseJsonWithCastAuto( + Doc, "lookup-cache-dynamic", + GraphRef.Params.speculative.lookup_cache_dynamic); + parseJsonWithCastAuto(Doc, "logits-file", + GraphRef.Params.logits_file); + parseJsonAuto(Doc, "lora-init-without-apply", + GraphRef.Params.lora_init_without_apply); + parseJsonWithCastAuto(Doc, "verbosity", GraphRef.Params.verbosity); + parseJsonWithCastAuto(Doc, "control-vector-layer-start", + GraphRef.Params.control_vector_layer_start); + parseJsonWithCastAuto(Doc, "control-vector-layer-end", + GraphRef.Params.control_vector_layer_end); + parseJsonWithCastAuto(Doc, "ppl-stride", + GraphRef.Params.ppl_stride); + parseJsonWithCastAuto(Doc, "ppl-output-type", + GraphRef.Params.ppl_output_type); + parseJsonAuto(Doc, "hellaswag", GraphRef.Params.hellaswag); + parseJsonWithCastAuto(Doc, "hellaswag-tasks", + GraphRef.Params.hellaswag_tasks); + parseJsonAuto(Doc, "winogrande", GraphRef.Params.winogrande); + parseJsonWithCastAuto(Doc, "winogrande-tasks", + GraphRef.Params.winogrande_tasks); + parseJsonAuto(Doc, "multiple-choice", + GraphRef.Params.multiple_choice); + parseJsonWithCastAuto( + Doc, "multiple-choice-tasks", GraphRef.Params.multiple_choice_tasks); + parseJsonAuto(Doc, "kl-divergence", GraphRef.Params.kl_divergence); + parseJsonAuto(Doc, "usage", GraphRef.Params.usage); + parseJsonAuto(Doc, "use-color", GraphRef.Params.use_color); + parseJsonAuto(Doc, "special", GraphRef.Params.special); + parseJsonAuto(Doc, "interactive", GraphRef.Params.interactive); + parseJsonAuto(Doc, "interactive-first", + GraphRef.Params.interactive_first); + parseJsonAuto(Doc, "prompt-cache-all", + GraphRef.Params.prompt_cache_all); + parseJsonAuto(Doc, "prompt-cache-ro", + GraphRef.Params.prompt_cache_ro); + parseJsonAuto(Doc, "escape", GraphRef.Params.escape); + parseJsonAuto(Doc, "multiline-input", + GraphRef.Params.multiline_input); + parseJsonAuto(Doc, "simple-io", GraphRef.Params.simple_io); + parseJsonAuto(Doc, "cont-batching", GraphRef.Params.cont_batching); + parseJsonWithProcessorAuto( + Doc, "flash-attn", + [&GraphRef](const std::string_view &FlashAttn) -> bool { + if (FlashAttn == "on"sv || FlashAttn == "enabled"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; + } else if (FlashAttn == "off"sv || FlashAttn == "disabled"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + } else if (FlashAttn == "auto"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; + } else { + spdlog::error( + "[WASI-NN] GGML backend: The flash-attn option must be " + "one of: on, off, auto."); + return false; + } + return true; + }); + parseJsonAuto(Doc, "no-perf", GraphRef.Params.no_perf); + parseJsonAuto(Doc, "ctx-shift", GraphRef.Params.ctx_shift); + parseJsonAuto(Doc, "input-prefix-bos", + GraphRef.Params.input_prefix_bos); + parseJsonAuto(Doc, "use-mlock", GraphRef.Params.use_mlock); + parseJsonAuto(Doc, "use-mmap", GraphRef.Params.use_mmap); + parseJsonAuto(Doc, "verbose-prompt", GraphRef.Params.verbose_prompt); + parseJsonAuto(Doc, "display-prompt", GraphRef.Params.display_prompt); + parseJsonAuto(Doc, "no-kv-offload", GraphRef.Params.no_kv_offload); + parseJsonAuto(Doc, "warmup", GraphRef.Params.warmup); + parseJsonAuto(Doc, "check-tensors", GraphRef.Params.check_tensors); + parseJsonWithCastAuto(Doc, "cache-type-k", + GraphRef.Params.cache_type_k); + parseJsonWithCastAuto(Doc, "cache-type-v", + GraphRef.Params.cache_type_v); + + parseJsonWithCastAuto(Doc, "embd-normalize", + GraphRef.Params.embd_normalize); + parseJsonWithCastAuto(Doc, "embd-out", + GraphRef.Params.embd_out); + + parseJsonWithCastAuto(Doc, "embd-sep", + GraphRef.Params.embd_sep); + parseJsonWithProcessorAuto( + Doc, "reranking", [&GraphRef](const bool &) -> bool { + GraphRef.Params.embedding = true; + GraphRef.Params.pooling_type = LLAMA_POOLING_TYPE_RANK; + return true; + }); + parseJsonWithCastAuto(Doc, "port", GraphRef.Params.port); + parseJsonWithCastAuto(Doc, "timeout-read", + GraphRef.Params.timeout_read); + parseJsonWithCastAuto(Doc, "timeout-write", + GraphRef.Params.timeout_write); + parseJsonWithCastAuto(Doc, "n-threads-http", + GraphRef.Params.n_threads_http); + parseJsonWithCastAuto(Doc, "n-cache-reuse", + GraphRef.Params.n_cache_reuse); + parseJsonWithCastAuto(Doc, "hostname", + GraphRef.Params.hostname); + parseJsonWithCastAuto(Doc, "public-path", + GraphRef.Params.public_path); + parseJsonWithCastAuto(Doc, "chat-template", + GraphRef.Params.chat_template); + parseJsonAuto(Doc, "enable-chat-template", + GraphRef.Params.enable_chat_template); + parseJsonWithCastAuto(Doc, "ssl-file-key", + GraphRef.Params.ssl_file_key); + parseJsonWithCastAuto(Doc, "ssl-file-cert", + GraphRef.Params.ssl_file_cert); + parseJsonAuto(Doc, "webui", GraphRef.Params.webui); + parseJsonAuto(Doc, "endpoint-slots", GraphRef.Params.endpoint_slots); + parseJsonAuto(Doc, "endpoint-props", GraphRef.Params.endpoint_props); + parseJsonAuto(Doc, "endpoint-metrics", + GraphRef.Params.endpoint_metrics); + parseJsonAuto(Doc, "log-json", GraphRef.Params.log_json); + + // Slot parameters + parseJsonWithCastAuto(Doc, "slot-save-path", + GraphRef.Params.slot_save_path); + + parseJsonWithCastAuto(Doc, "slot-prompt-similarity", + GraphRef.Params.slot_prompt_similarity); + parseJsonAuto(Doc, "is-pp-shared", GraphRef.Params.is_pp_shared); + parseJsonWithProcessorAuto( + Doc, "n-pp", [&GraphRef](const int64_t &NPP) -> bool { + GraphRef.Params.n_pp.emplace_back(NPP); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "n-tg", [&GraphRef](const int64_t &NTG) -> bool { + GraphRef.Params.n_tg.emplace_back(NTG); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "n-pl", [&GraphRef](const int64_t &NPL) -> bool { + GraphRef.Params.n_pl.emplace_back(NPL); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "context-files", + [&GraphRef](const std::string_view &ContextFile) -> bool { + GraphRef.Params.context_files.emplace_back(ContextFile); + return true; + }); + parseJsonWithCastAuto(Doc, "chunk-size", + GraphRef.Params.chunk_size); + parseJsonWithCastAuto(Doc, "chunk-separator", + GraphRef.Params.chunk_separator); + parseJsonWithCastAuto(Doc, "n-junk", GraphRef.Params.n_junk); + parseJsonWithCastAuto(Doc, "i-pos", GraphRef.Params.i_pos); + parseJsonWithCastAuto(Doc, "out-file", + GraphRef.Params.out_file); + parseJsonWithCastAuto(Doc, "n-out-freq", + GraphRef.Params.n_out_freq); + parseJsonWithCastAuto(Doc, "n-save-freq", + GraphRef.Params.n_save_freq); + parseJsonWithCastAuto(Doc, "i-chunk", GraphRef.Params.i_chunk); + parseJsonAuto(Doc, "process-output", GraphRef.Params.process_output); + parseJsonAuto(Doc, "compute-ppl", GraphRef.Params.compute_ppl); + parseJsonWithCastAuto(Doc, "n-pca-batch", + GraphRef.Params.n_pca_batch); + parseJsonWithCastAuto(Doc, "n-pca-iterations", + GraphRef.Params.n_pca_iterations); + parseJsonWithProcessorAuto( + Doc, "cvector-dimre-method", + [&GraphRef](const std::string_view &Method) -> bool { + if (Method == "DIMRE_METHOD_PCA") { + GraphRef.Params.cvector_dimre_method = + dimre_method::DIMRE_METHOD_PCA; + return true; + } + if (Method == "DIMRE_METHOD_MEAN") { + GraphRef.Params.cvector_dimre_method = + dimre_method::DIMRE_METHOD_MEAN; + return true; + } + return false; + }); + parseJsonWithCastAuto( + Doc, "cvector-positive-file", GraphRef.Params.cvector_positive_file); + parseJsonWithCastAuto( + Doc, "cvector-negative-file", GraphRef.Params.cvector_negative_file); + parseJsonAuto(Doc, "spm-infill", GraphRef.Params.spm_infill); + parseJsonWithCastAuto(Doc, "out-file", + GraphRef.Params.out_file); + parseJsonAuto(Doc, "batched-bench-output-jsonl", + GraphRef.Params.batched_bench_output_jsonl); + } catch (const ErrNo &Error) { + return Error; + } + // The tensor buffer overrides should be terminated with an empty pattern. + if (!GraphRef.TensorBuftOverrides.empty()) { + for (const std::string &Override : GraphRef.TensorBuftOverrides) { + GraphRef.Params.tensor_buft_overrides.push_back( + {Override.c_str(), ggml_backend_cpu_buffer_type()}); + } + GraphRef.Params.tensor_buft_overrides.push_back({nullptr, nullptr}); + } + + if (GraphRef.TextToSpeech) { + GraphRef.Params.sampling.top_k = 4; + GraphRef.Params.sampling.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + } + // Check if the model parameters are updated. + if (IsModelUpdated && PrevNGPULayers != GraphRef.Params.n_gpu_layers) { + *IsModelUpdated = true; + } + + // Check if the context parameters are updated. + if (IsContextUpdated && PrevEmbedding != GraphRef.Params.embedding) { + *IsContextUpdated = true; + } + + // Check if the sampler parameters are updated. + if (IsSamplerUpdated && + (PrevTemp != GraphRef.Params.sampling.temp || + PrevTopP != GraphRef.Params.sampling.top_p || + PrevRepeatPenalty != GraphRef.Params.sampling.penalty_repeat || + PrevPresencePenalty != GraphRef.Params.sampling.penalty_present || + PrevFrequencyPenalty != GraphRef.Params.sampling.penalty_freq || + PrevGrammar != common_grammar_value(GraphRef.Params.sampling.grammar) || + PrevSeed != GraphRef.Params.sampling.seed)) { + *IsSamplerUpdated = true; + } + + return ErrNo::Success; +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/metadata/metadata_parser.h b/plugins/wasi_nn/GGML/metadata/metadata_parser.h new file mode 100644 index 00000000..40971146 --- /dev/null +++ b/plugins/wasi_nn/GGML/metadata/metadata_parser.h @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "GGML/core/ggml_core.h" +#include "simdjson.h" +#include "wasinntypes.h" + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +// JSON parsing helper template +template +T getJsonValue(const simdjson::dom::element &Doc, std::string_view Key) { + if (Doc.at_key(Key).error() == simdjson::SUCCESS) { + T Value{}; + auto Err = Doc[Key].get().get(Value); + if (Err) { + std::string Msg = fmt::format("Unable to retrieve the {} option.", Key); + spdlog::error("[WASI-NN] GGML backend: {}", Msg); + throw ErrNo(ErrNo::InvalidArgument); + } + return Value; + } + throw ErrNo(ErrNo::NotFound); +} + +template +void parseJsonAuto(const simdjson::dom::element &Doc, std::string_view Key, + T &Var) { + try { + auto Result = getJsonValue(Doc, Key); + Var = Result; + } catch (ErrNo E) { + if (E != ErrNo::NotFound) { + throw E; + } + } +} + +template +void parseJsonWithCastAuto(const simdjson::dom::element &Doc, + std::string_view Key, ToType &Var) { + try { + auto Result = getJsonValue(Doc, Key); + Var = static_cast(Result); + } catch (ErrNo E) { + if (E != ErrNo::NotFound) { + throw E; + } + } +} + +template +void parseJsonWithProcessorAuto(const simdjson::dom::element &Doc, + std::string_view Key, Processor Proc) { + try { + auto Result = getJsonValue(Doc, Key); + if (!Proc(Result)) { + throw ErrNo{ErrNo::InvalidArgument}; + } + } catch (ErrNo E) { + if (E != ErrNo::NotFound) { + throw E; + } + } +} + +} // namespace +ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, + const std::string &Metadata, bool *IsModelUpdated = nullptr, + bool *IsContextUpdated = nullptr, + bool *IsSamplerUpdated = nullptr) noexcept; +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/tts/tts_core.cpp b/plugins/wasi_nn/GGML/tts/tts_core.cpp new file mode 100644 index 00000000..86d332bd --- /dev/null +++ b/plugins/wasi_nn/GGML/tts/tts_core.cpp @@ -0,0 +1,481 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "tts_core.h" +#include "host/wasi/vfs_io.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +// TTS related functions. +void fillHannWindow(int Length, bool Periodic, float *Output) { + int Offset = -1; + float Pi = static_cast(std::acos(-1)); + if (Periodic) { + Offset = 0; + } + for (int I = 0; I < Length; I++) { + double Value = + 0.5 * + (1.0 - cosf(static_cast((2.0 * Pi * I) / (Length + Offset)))); + Output[I] = static_cast(Value); + } +} + +void twiddle(float *Real, float *Imag, int K, int N) { + float Pi = static_cast(std::acos(-1)); + float Angle = 2 * Pi * K / N; + *Real = cos(Angle); + *Imag = sin(Angle); +} + +void irfft(int N, const float *InpCplx, float *OutReal) { + int NN = N / 2 + 1; + + std::vector RealInput(NN); + std::vector ImagInput(NN); + for (int I = 0; I < NN; ++I) { + RealInput[I] = InpCplx[2 * I]; + ImagInput[I] = InpCplx[2 * I + 1]; + } + + std::vector RealOutput(N); + std::vector ImagOutput(N); + + for (int K = 0; K < N; ++K) { + RealOutput[K] = 0.0f; + ImagOutput[K] = 0.0f; + for (int M = 0; M < NN; ++M) { + float TwiddleReal; + float TwiddleImag; + + twiddle(&TwiddleReal, &TwiddleImag, K * M, N); + + RealOutput[K] += RealInput[M] * TwiddleReal - ImagInput[M] * TwiddleImag; + ImagOutput[K] += RealInput[M] * TwiddleImag + ImagInput[M] * TwiddleReal; + } + } + + for (int I = 0; I < N; ++I) { + OutReal[I] = RealOutput[I] / NN; + } +} + +void fold(const std::vector &Data, int64_t NOut, int64_t NWin, + int64_t NHop, int64_t NPad, std::vector &Output) { + int64_t OutputHeight = NOut; + int64_t KernelW = NWin; + int64_t StrideW = NHop; + int64_t Width = NOut; + + Output.resize(Width, 0.0f); + + int64_t ColIdx = 0; + for (int64_t WCol = 0; WCol < Width; ++WCol) { + int64_t Start = WCol * StrideW - NPad; + int64_t End = Start + KernelW; + + for (int64_t WIm = Start; WIm < End; ++WIm) { + if (WIm >= 0 && WIm < OutputHeight && + ColIdx < static_cast(Data.size())) { + Output[WIm] += Data[ColIdx]; + } + ColIdx++; + } + } + + Output.resize(NOut - 2 * NPad); +} + +std::vector embdToAudio(const float *Embd, const int NCodes, + const int NEmbd, const int NThread) { + const int NFft = 1280; + const int NHop = 320; + const int NWin = 1280; + const int NPad = (NWin - NHop) / 2; + const int NOut = (NCodes - 1) * NHop + NWin; + + std::vector Hann(NFft); + + fillHannWindow(static_cast(Hann.size()), true, Hann.data()); + + int NSpec = NEmbd * NCodes; + + std::vector E(NSpec); + std::vector S(NSpec); + std::vector ST(NSpec); + + for (int L = 0; L < NCodes; ++L) { + for (int K = 0; K < NEmbd; ++K) { + E[K * NCodes + L] = Embd[L * NEmbd + K]; + } + } + + for (int K = 0; K < NEmbd / 2; ++K) { + for (int L = 0; L < NCodes; ++L) { + float Mag = E[(K)*NCodes + L]; + float Phi = E[(K + NEmbd / 2) * NCodes + L]; + + Mag = exp(Mag); + + if (Mag > 1e2) { + Mag = 1e2; + } + S[2 * (K * NCodes + L) + 0] = Mag * cosf(Phi); + S[2 * (K * NCodes + L) + 1] = Mag * sinf(Phi); + } + } + + for (int L = 0; L < NCodes; ++L) { + for (int K = 0; K < NEmbd / 2; ++K) { + ST[L * NEmbd + 2 * K + 0] = S[2 * (K * NCodes + L) + 0]; + ST[L * NEmbd + 2 * K + 1] = S[2 * (K * NCodes + L) + 1]; + } + } + + std::vector Res(NCodes * NFft); + std::vector Hann2(NCodes * NFft); + + std::vector Workers(NThread); + for (int I = 0; I < NThread; ++I) { + Workers[I] = std::thread([&, I]() { + for (int L = I; L < NCodes; L += NThread) { + irfft(NFft, ST.data() + L * NEmbd, Res.data() + L * NFft); + for (int J = 0; J < NFft; ++J) { + Res[L * NFft + J] *= Hann[J]; + Hann2[L * NFft + J] = Hann[J] * Hann[J]; + } + } + }); + } + for (int I = 0; I < NThread; ++I) { + Workers[I].join(); + } + + std::vector Audio; + std::vector Env; + + fold(Res, NOut, NWin, NHop, NPad, Audio); + fold(Hann2, NOut, NWin, NHop, NPad, Env); // TODO: can be done once + + for (size_t I = 0; I < Audio.size(); ++I) { + Audio[I] /= Env[I]; + } + + return Audio; +} + +struct WavHeader { + char Riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t ChunkSize; + char Wave[4] = {'W', 'A', 'V', 'E'}; + char Fmt[4] = {'f', 'm', 't', ' '}; + uint32_t FmtChunkSize = 16; + uint16_t AudioFormat = 1; // PCM + uint16_t NumChannels = 1; // Mono + uint32_t SampleRate; + uint32_t ByteRate; + uint16_t BlockAlign; + uint16_t BitsPerSample = 16; + char Data[4] = {'d', 'a', 't', 'a'}; + uint32_t DataSize; +}; + +std::vector audioDataToWav(const std::vector &Data, + int SampleRate) { + std::vector WavData; + WavHeader Header; + Header.SampleRate = SampleRate; + Header.ByteRate = + Header.SampleRate * Header.NumChannels * (Header.BitsPerSample / 8); + Header.BlockAlign = Header.NumChannels * (Header.BitsPerSample / 8); + Header.DataSize = + static_cast(Data.size() * (Header.BitsPerSample / 8)); + Header.ChunkSize = 36 + Header.DataSize; + + WavData.insert(WavData.end(), reinterpret_cast(&Header), + reinterpret_cast(&Header) + sizeof(Header)); + + for (const auto &Sample : Data) { + int16_t PCMSample = + static_cast(std::clamp(Sample * 32767.0, -32768.0, 32767.0)); + WavData.insert(WavData.end(), reinterpret_cast(&PCMSample), + reinterpret_cast(&PCMSample) + sizeof(PCMSample)); + } + + return WavData; +} + +// Convert a number less than 1000 to words +std::string convertLessThanThousand(int Num) { + std::string Result; + + if (Num >= 100) { + Result += Ones.at(Num / 100) + " hundred "; + Num %= 100; + } + + if (Num >= 20) { + Result += Tens.at(Num / 10); + if (Num % 10 > 0) { + Result += "-" + Ones.at(Num % 10); + } + } else if (Num > 0) { + Result += Ones.at(Num); + } + + return Result; +} + +std::string numberToWords(const std::string &NumberStr) { + try { + size_t DecimalPos = NumberStr.find('.'); + std::string IntegerPart = NumberStr.substr(0, DecimalPos); + + int IntNumber = std::stoi(IntegerPart); + std::string Result; + + if (IntNumber == 0) { + Result = "zero"; + } else { + if (IntNumber >= 1000000000) { + int Billions = IntNumber / 1000000000; + Result += convertLessThanThousand(Billions) + " billion "; + IntNumber %= 1000000000; + } + + if (IntNumber >= 1000000) { + int Millions = IntNumber / 1000000; + Result += convertLessThanThousand(Millions) + " million "; + IntNumber %= 1000000; + } + + if (IntNumber >= 1000) { + int Thousands = IntNumber / 1000; + Result += convertLessThanThousand(Thousands) + " thousand "; + IntNumber %= 1000; + } + + if (IntNumber > 0) { + Result += convertLessThanThousand(IntNumber); + } + } + + // Handle decimal part + if (DecimalPos != std::string::npos) { + Result += " point"; + std::string DecimalPart = NumberStr.substr(DecimalPos + 1); + for (char Digit : DecimalPart) { + Result += " " + Ones.at(Digit - '0'); + } + } + + return Result; + } catch (const std::exception &) { + // Skip if fails + return " "; + } +} + +std::string replaceNumbersWithWords(const std::string &InputText) { + std::regex NumberPattern(R"(\d+(\.\d+)?)"); + std::string Result; + auto It = + std::sregex_iterator(InputText.begin(), InputText.end(), NumberPattern); + auto End = std::sregex_iterator(); + + size_t LastPos = 0; + for (std::sregex_iterator I = It; I != End; ++I) { + const std::smatch &Match = *I; + Result.append(InputText, LastPos, Match.position() - LastPos); + Result.append(numberToWords(Match.str())); + LastPos = Match.position() + Match.length(); + } + Result.append(InputText, LastPos); + + return Result; +} + +} // namespace + +std::vector processTTSPrompt(WasiNNEnvironment &Env, + Graph &GraphRef, + std::string &Prompt) noexcept { + // Use the custom speaker profile if available. + TTSSpeakerProfile SpeakerProfile = TTSDefaultSpeakerProfile; + if (!GraphRef.TTSSpeakerFilePath.empty()) { + std::optional SpeakerProfileOpt = + getSpeakerProfileFromFile(GraphRef.TTSSpeakerFilePath, Env); + if (SpeakerProfileOpt.has_value()) { + SpeakerProfile = *SpeakerProfileOpt; + } else { + RET_ERROR( + {}, + "processTTSPrompt: Failed to load speaker profile from file: {}"sv, + GraphRef.TTSSpeakerFilePath); + } + } + std::string ProcessedPrompt = processTTSPromptText(Prompt); + std::vector Result, TmpTokens; + Result = common_tokenize(GraphRef.LlamaContext.get(), "<|im_start|>\n", + /* add_special */ true, + /* parse_special */ true); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), SpeakerProfile.Text, + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), ProcessedPrompt, + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), "<|text_end|>\n", + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), SpeakerProfile.Data, + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + + return Result; +} + +// Based on: +// https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39 +// https://github.com/ggerganov/llama.cpp/blob/b4488/examples/tts/tts.cpp#L374 +std::string processTTSPromptText(const std::string &Text) { + std::string ProcessedText = replaceNumbersWithWords(Text); + + std::transform( + ProcessedText.begin(), ProcessedText.end(), ProcessedText.begin(), + [](unsigned char C) { return static_cast(::tolower(C)); }); + + std::regex SpecialChars(R"([-_/,\.\\])"); + ProcessedText = std::regex_replace(ProcessedText, SpecialChars, " "); + + std::regex NonAlpha(R"([^a-z\s])"); + ProcessedText = std::regex_replace(ProcessedText, NonAlpha, ""); + + std::regex MultipleSpaces(R"(\s+)"); + ProcessedText = std::regex_replace(ProcessedText, MultipleSpaces, " "); + + ProcessedText = + std::regex_replace(ProcessedText, std::regex(R"(^\s+|\s+$)"), ""); + + ProcessedText = + std::regex_replace(ProcessedText, std::regex(R"(\s)"), "<|text_sep|>"); + + return ProcessedText; +} + +std::optional +getSpeakerProfileFromFile(const std::string &FilePath, WasiNNEnvironment &Env) { + WasmEdge::FStream::IFStream JsonFile(FilePath, Env.getEnv()); + if (!JsonFile.is_open()) { + return std::nullopt; + } + nlohmann::json JsonData; + JsonFile >> JsonData; + JsonFile.close(); + + // Initialize the outputs. + std::string AudioOutputText = "<|audio_start|>\n"; + std::string TextOutput = "<|text_start|>"; + + // Iterate through each word in the JSON data + for (const auto &WordData : JsonData["words"]) { + std::string Word = WordData["word"]; + double Duration = WordData["duration"]; + std::vector Codes = WordData["codes"]; + + // Create the audio output entry + std::ostringstream WordEntry; + WordEntry << Word << "<|t_" << std::fixed << std::setprecision(2) + << Duration << "|><|code_start|>"; + for (const auto &Code : Codes) { + WordEntry << "<|" << Code << "|>"; + } + WordEntry << "<|code_end|>\n"; + AudioOutputText += WordEntry.str(); + + // Create the text output entry + TextOutput += Word + "<|text_sep|>"; + } + + return TTSSpeakerProfile{TextOutput, AudioOutputText}; +} + +// TextToSpeech function that generates voice data from codes. +ErrNo codesToSpeech(WasiNNEnvironment &Env, Graph &GraphRef, + Context &CxtRef) noexcept { + // Remove all non-audio tokens. + CxtRef.LlamaOutputTokens.erase( + std::remove_if(CxtRef.LlamaOutputTokens.begin(), + CxtRef.LlamaOutputTokens.end(), + [](llama_token T) { return T < 151672 || T > 155772; }), + CxtRef.LlamaOutputTokens.end()); + + // Adjust the token values for audio data. + for (llama_token &Token : CxtRef.LlamaOutputTokens) { + Token -= 151672; + } + + // Put codes into batch. + const uint32_t NCodes = + static_cast(CxtRef.LlamaOutputTokens.size()); + llama_batch TTSBatch = + llama_batch_init(NCodes, /* embd */ 0, /* n_seq_max */ 1); + for (uint32_t I = 0; I < NCodes; ++I) { + common_batch_add(TTSBatch, CxtRef.LlamaOutputTokens[I], I, + /* seq_ids */ {0}, /* logits */ true); + } + if (llama_decode(GraphRef.TTSContext.get(), TTSBatch) != 0) { + RET_ERROR(ErrNo::RuntimeError, "codesToSpeech: fail to eval."sv) + } + llama_batch_free(TTSBatch); + + // Get embeddings. + const int NEmbd = llama_model_n_embd(GraphRef.TTSModel.get()); + const float *Embd = llama_get_embeddings(GraphRef.TTSContext.get()); + + // Embeddings to audio. + std::vector AudioData = + embdToAudio(Embd, NCodes, NEmbd, + static_cast(GraphRef.Params.cpuparams.n_threads)); + + // Zero out first 0.25 seconds of audio. + const uint32_t SamplingRate = 24000; + for (uint32_t I = 0; I < SamplingRate / 4; ++I) { + AudioData[I] = 0.0f; + } + + // Convert audio data to WAV and put it into the output buffer. + CxtRef.LlamaOutputs = audioDataToWav(AudioData, SamplingRate); + + // Save .wav file if path is provided. + if (!GraphRef.TTSOutputFilePath.empty()) { + WasmEdge::FStream::OFStream File(GraphRef.TTSOutputFilePath, + std::ios_base::out | std::ios_base::binary, + Env.getEnv()); + if (!File) { + RET_ERROR(ErrNo::RuntimeError, + "codesToSpeech: Failed to open file '{}' for writing"sv, + GraphRef.TTSOutputFilePath); + } + File.write(reinterpret_cast(CxtRef.LlamaOutputs.data()), + CxtRef.LlamaOutputs.size()); + File.close(); + } + + return ErrNo::Success; +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/tts/tts_core.h b/plugins/wasi_nn/GGML/tts/tts_core.h new file mode 100644 index 00000000..f8ff281c --- /dev/null +++ b/plugins/wasi_nn/GGML/tts/tts_core.h @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "GGML/core/ggml_core.h" + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +struct TTSSpeakerProfile { + std::string Text; + std::string Data; +}; + +// TTS function to process the prompt text. +// clang-format off +const TTSSpeakerProfile TTSDefaultSpeakerProfile = { + // Speaker profile from edwko/OuteTTS (en_female_1.json). + "<|text_start|>uhm<|text_sep|>now<|text_sep|>being<|text_sep|>the<|text_sep|>one<|text_sep|>to<|text_sep|>say<|text_sep|>i<|text_sep|>know<|text_sep|>the<|text_sep|>worst<|text_sep|>of<|text_sep|>you<|text_sep|>and<|text_sep|>ive<|text_sep|>been<|text_sep|>directly<|text_sep|>affected<|text_sep|>by<|text_sep|>people<|text_sep|>like<|text_sep|>you<|text_sep|>but<|text_sep|>its<|text_sep|>a<|text_sep|>clean<|text_sep|>slate<|text_sep|>with<|text_sep|>me<|text_sep|>buddy<|text_sep|>you<|text_sep|>know<|text_sep|>like<|text_sep|>thats<|text_sep|>really<|text_sep|>powerful<|text_sep|>in<|text_sep|>and<|text_sep|>of<|text_sep|>itself<|text_sep|>", + "<|audio_start|>\nuhm<|t_0.36|><|code_start|><|447|><|223|><|967|><|301|><|965|><|827|><|393|><|908|><|764|><|1167|><|711|><|1222|><|324|><|1318|><|806|><|498|><|1198|><|1127|><|1178|><|916|><|1234|><|1411|><|1428|><|706|><|427|><|1605|><|1578|><|code_end|>\nnow<|t_0.36|><|code_start|><|1049|><|327|><|385|><|1070|><|732|><|1480|><|450|><|1025|><|1469|><|174|><|1013|><|1710|><|1674|><|775|><|771|><|251|><|778|><|1400|><|897|><|1487|><|366|><|441|><|1000|><|393|><|271|><|1000|><|768|><|code_end|>\nbeing<|t_0.27|><|code_start|><|926|><|406|><|1457|><|437|><|1231|><|672|><|1785|><|521|><|1179|><|1559|><|198|><|1086|><|733|><|122|><|1344|><|845|><|348|><|1389|><|470|><|1773|><|code_end|>\nthe<|t_0.08|><|code_start|><|1775|><|562|><|768|><|1222|><|768|><|963|><|code_end|>\none<|t_0.21|><|code_start|><|1757|><|744|><|144|><|1610|><|655|><|616|><|1317|><|225|><|1325|><|913|><|1342|><|992|><|1018|><|80|><|1777|><|883|><|code_end|>\nto<|t_0.08|><|code_start|><|487|><|1363|><|1682|><|1426|><|655|><|1483|><|code_end|>\nsay<|t_0.27|><|code_start|><|1644|><|1804|><|731|><|273|><|1592|><|731|><|1523|><|1404|><|984|><|1207|><|430|><|1132|><|1123|><|768|><|1116|><|829|><|1082|><|1095|><|440|><|1162|><|code_end|>\ni<|t_0.33|><|code_start|><|1330|><|335|><|1162|><|1155|><|308|><|1162|><|1150|><|1481|><|612|><|674|><|712|><|1745|><|1188|><|1787|><|1135|><|1275|><|1237|><|1143|><|408|><|1063|><|393|><|927|><|1298|><|132|><|1686|><|code_end|>\nknow<|t_0.27|><|code_start|><|983|><|1677|><|586|><|1528|><|1435|><|835|><|1396|><|706|><|987|><|22|><|1172|><|218|><|1404|><|1001|><|521|><|1389|><|775|><|1416|><|877|><|120|><|code_end|>\nthe<|t_0.16|><|code_start|><|916|><|1756|><|513|><|1245|><|1392|><|89|><|1266|><|12|><|1045|><|1075|><|904|><|35|><|code_end|>\nworst<|t_0.32|><|code_start|><|1607|><|174|><|1231|><|144|><|932|><|490|><|771|><|1504|><|798|><|674|><|364|><|80|><|1314|><|1636|><|449|><|1704|><|713|><|1795|><|968|><|1527|><|1302|><|1529|><|1176|><|795|><|code_end|>\nof<|t_0.12|><|code_start|><|1193|><|1205|><|390|><|1128|><|1091|><|883|><|322|><|377|><|1070|><|code_end|>\nyou<|t_0.17|><|code_start|><|1016|><|1332|><|926|><|281|><|927|><|1368|><|1687|><|918|><|67|><|1638|><|1317|><|1265|><|1770|><|code_end|>\nand<|t_0.28|><|code_start|><|1129|><|1633|><|1373|><|1207|><|405|><|879|><|1030|><|1253|><|1071|><|612|><|724|><|1770|><|665|><|1046|><|1351|><|1450|><|1541|><|1384|><|111|><|1477|><|284|><|code_end|>\nive<|t_0.35|><|code_start|><|674|><|266|><|89|><|1333|><|1183|><|1526|><|1143|><|883|><|1135|><|732|><|827|><|1119|><|594|><|1261|><|1024|><|1347|><|92|><|1392|><|825|><|1710|><|1289|><|1598|><|1070|><|1525|><|1442|><|555|><|code_end|>\nbeen<|t_0.17|><|code_start|><|1461|><|194|><|337|><|1128|><|188|><|892|><|848|><|1280|><|959|><|754|><|231|><|649|><|1304|><|code_end|>\ndirectly<|t_0.87|><|code_start|><|1030|><|353|><|570|><|1331|><|470|><|1832|><|1362|><|1809|><|1383|><|101|><|325|><|1557|><|1242|><|1512|><|180|><|227|><|1242|><|643|><|209|><|464|><|171|><|1219|><|174|><|1723|><|734|><|118|><|1269|><|643|><|209|><|187|><|612|><|1231|><|68|><|567|><|1242|><|505|><|319|><|1268|><|794|><|678|><|40|><|1286|><|470|><|1454|><|199|><|965|><|188|><|300|><|1234|><|1125|><|794|><|1289|><|1224|><|257|><|469|><|1121|><|101|><|823|><|1769|><|1683|><|95|><|255|><|59|><|67|><|832|><|code_end|>\naffected<|t_0.44|><|code_start|><|510|><|873|><|787|><|1228|><|771|><|1428|><|501|><|751|><|696|><|258|><|845|><|1818|><|1112|><|498|><|1111|><|985|><|1073|><|832|><|1427|><|168|><|163|><|447|><|119|><|567|><|1626|><|1820|><|903|><|635|><|1060|><|10|><|1632|><|35|><|1635|><|code_end|>\nby<|t_0.19|><|code_start|><|144|><|144|><|460|><|185|><|1112|><|1044|><|498|><|1192|><|656|><|1333|><|1001|><|1186|><|1186|><|454|><|code_end|>\npeople<|t_0.48|><|code_start|><|1260|><|747|><|351|><|526|><|612|><|1151|><|1262|><|1791|><|344|><|1752|><|1547|><|930|><|1302|><|1703|><|1289|><|92|><|1407|><|1482|><|508|><|1431|><|355|><|1696|><|337|><|199|><|1157|><|223|><|464|><|568|><|845|><|411|><|826|><|718|><|1786|><|545|><|712|><|580|><|code_end|>\nlike<|t_0.32|><|code_start|><|630|><|532|><|526|><|607|><|526|><|839|><|1305|><|660|><|459|><|339|><|717|><|1178|><|1148|><|687|><|149|><|1390|><|229|><|199|><|513|><|712|><|1451|><|731|><|582|><|1551|><|code_end|>\nyou<|t_0.21|><|code_start|><|1389|><|954|><|1781|><|1047|><|1236|><|930|><|809|><|1621|><|1268|><|384|><|242|><|587|><|869|><|816|><|1680|><|405|><|code_end|>\nbut<|t_0.59|><|code_start|><|1089|><|1590|><|908|><|80|><|594|><|1046|><|1706|><|1025|><|1150|><|405|><|548|><|893|><|1285|><|464|><|301|><|939|><|643|><|23|><|285|><|161|><|209|><|453|><|72|><|167|><|417|><|244|><|151|><|643|><|391|><|199|><|651|><|1023|><|337|><|1010|><|54|><|331|><|1167|><|756|><|388|><|934|><|1060|><|18|><|1624|><|1060|><|code_end|>\nits<|t_0.16|><|code_start|><|1102|><|183|><|1199|><|1258|><|1285|><|35|><|659|><|180|><|426|><|1587|><|1733|><|942|><|code_end|>\na<|t_0.04|><|code_start|><|791|><|1012|><|818|><|code_end|>\nclean<|t_0.61|><|code_start|><|1819|><|976|><|163|><|447|><|316|><|223|><|763|><|457|><|1208|><|1808|><|1697|><|1162|><|1660|><|1833|><|1054|><|1734|><|1121|><|1309|><|1643|><|924|><|1677|><|1548|><|869|><|1268|><|223|><|674|><|111|><|792|><|1670|><|912|><|174|><|1554|><|90|><|80|><|1563|><|1621|><|1698|><|1544|><|992|><|988|><|175|><|793|><|1661|><|1026|><|80|><|1761|><|code_end|>\nslate<|t_0.40|><|code_start|><|1802|><|322|><|1689|><|1577|><|1302|><|1552|><|1529|><|1722|><|1580|><|582|><|1642|><|1529|><|1020|><|582|><|1538|><|970|><|437|><|1141|><|1477|><|988|><|335|><|1611|><|922|><|1558|><|1120|><|1189|><|423|><|188|><|171|><|562|><|code_end|>\nwith<|t_0.15|><|code_start|><|963|><|1347|><|1274|><|747|><|1230|><|712|><|1408|><|1290|><|957|><|1279|><|258|><|code_end|>\nme<|t_0.09|><|code_start|><|638|><|1058|><|174|><|1452|><|1038|><|894|><|1571|><|code_end|>\nbuddy<|t_0.32|><|code_start|><|1003|><|130|><|1341|><|938|><|40|><|804|><|167|><|89|><|1456|><|1189|><|1155|><|1171|><|1434|><|1077|><|1029|><|1455|><|1622|><|1037|><|163|><|1411|><|1165|><|1463|><|837|><|1202|><|code_end|>\nyou<|t_0.36|><|code_start|><|1354|><|1165|><|615|><|1588|><|1192|><|1445|><|1033|><|982|><|401|><|1079|><|684|><|1570|><|266|><|31|><|420|><|163|><|893|><|845|><|905|><|1827|><|1804|><|153|><|627|><|243|><|1179|><|298|><|1147|><|code_end|>\nknow<|t_0.19|><|code_start|><|163|><|1542|><|1366|><|698|><|1753|><|206|><|916|><|1499|><|245|><|665|><|600|><|894|><|587|><|1741|><|code_end|>\nlike<|t_0.24|><|code_start|><|1106|><|1280|><|1062|><|1304|><|945|><|809|><|598|><|104|><|1001|><|822|><|965|><|189|><|693|><|1810|><|1293|><|199|><|1277|><|44|><|code_end|>\nthats<|t_0.24|><|code_start|><|121|><|1789|><|1443|><|370|><|1154|><|393|><|1178|><|1200|><|1264|><|424|><|1391|><|381|><|978|><|1346|><|704|><|1808|><|1579|><|1492|><|code_end|>\nreally<|t_0.56|><|code_start|><|1177|><|1761|><|1723|><|1360|><|1413|><|830|><|551|><|193|><|59|><|332|><|598|><|734|><|1684|><|1802|><|60|><|1590|><|353|><|89|><|1636|><|1396|><|893|><|143|><|455|><|1501|><|435|><|1082|><|621|><|1593|><|677|><|474|><|971|><|1513|><|913|><|828|><|1381|><|1148|><|1798|><|1186|><|1443|><|38|><|335|><|883|><|code_end|>\npowerful<|t_0.63|><|code_start|><|1773|><|458|><|1070|><|964|><|826|><|1220|><|1012|><|1738|><|1125|><|669|><|490|><|1169|><|922|><|958|><|1204|><|489|><|1001|><|886|><|1045|><|675|><|1471|><|1652|><|732|><|698|><|1124|><|480|><|897|><|1484|><|1028|><|35|><|594|><|1465|><|505|><|1669|><|436|><|851|><|1288|><|31|><|1501|><|1187|><|394|><|909|><|1541|><|1793|><|1720|><|922|><|840|><|code_end|>\nin<|t_0.16|><|code_start|><|1317|><|523|><|630|><|1343|><|1187|><|719|><|907|><|636|><|111|><|1524|><|188|><|1382|><|code_end|>\nand<|t_0.13|><|code_start|><|1074|><|922|><|1280|><|1496|><|1050|><|832|><|133|><|1435|><|1049|><|1774|><|code_end|>\nof<|t_0.12|><|code_start|><|960|><|1052|><|1192|><|1303|><|1112|><|970|><|417|><|60|><|1155|><|code_end|>\nitself<|t_0.47|><|code_start|><|1682|><|1209|><|1410|><|513|><|1222|><|861|><|167|><|406|><|1551|><|582|><|634|><|1529|><|786|><|1363|><|1578|><|1739|><|873|><|424|><|1041|><|1328|><|955|><|1110|><|1490|><|1424|><|1199|><|988|><|1162|><|1133|><|1193|><|978|><|470|><|832|><|963|><|1251|><|733|><|code_end|>", +}; +// clang-format on + +const std::map Ones = { + {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, + {4, "four"}, {5, "five"}, {6, "six"}, {7, "seven"}, + {8, "eight"}, {9, "nine"}, {10, "ten"}, {11, "eleven"}, + {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"}, {15, "fifteen"}, + {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}}; + +const std::map Tens = { + {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"}, + {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}}; + +std::vector embdToAudio(const float *Embd, const int NCodes, + const int NEmbd, const int NThread); +std::vector audioDataToWav(const std::vector &Data, + int SampleRate); +} // namespace + +std::string processTTSPromptText(const std::string &Text); +std::optional +getSpeakerProfileFromFile(const std::string &FilePath, WasiNNEnvironment &Env); + +std::vector processTTSPrompt(WasiNNEnvironment &Env, + Graph &GraphRef, + std::string &Prompt) noexcept; +ErrNo codesToSpeech(WasiNNEnvironment &Env, Graph &GraphRef, + Context &CxtRef) noexcept; +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/utils.cpp b/plugins/wasi_nn/GGML/utils.cpp new file mode 100644 index 00000000..29df8ab6 --- /dev/null +++ b/plugins/wasi_nn/GGML/utils.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "utils.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +// Helper to initialize a llama batch. +struct llama_batch allocBatch(int64_t NTokens, int64_t Embd, + int32_t NSeqMax) noexcept { + struct llama_batch Batch = llama_batch_init( + /* n_tokens_alloc */ static_cast(NTokens), + /* embd */ static_cast(Embd), + /* n_seq_max */ static_cast(NSeqMax)); + std::fill(Batch.n_seq_id, Batch.n_seq_id + NTokens, + static_cast(NSeqMax)); + for (int64_t I = 0; I < NTokens; I++) { + std::fill(Batch.seq_id[I], Batch.seq_id[I] + NSeqMax, 0); + } + std::fill(Batch.logits, Batch.logits + NTokens, false); + return Batch; +} + +// Get base64 image position if found in prompt. +std::optional> +findBase64ImagePayload(std::string_view Prompt, bool IsDebugLog) noexcept { + // Find `` + auto EndTagPos = Prompt.find(Base64ImageTagSuffix, PayloadPos); + if (EndTagPos == std::string::npos) { + LOG_DEBUG(IsDebugLog, "base64: image tag unclosed."sv) + return std::nullopt; + } + return std::make_tuple(BeginTagPos, PayloadPos, EndTagPos); +} + +// Extract base64 image payload and image type. Replace it with placeholder. +std::optional, std::string>> +extractBase64ImagePayload(std::string &Prompt, + std::tuple ImagePos, + const std::string_view Placeholder) noexcept { + // Locate the payload and image type. + size_t BeginTagPos = std::get<0>(ImagePos); + size_t TypePos = std::get<0>(ImagePos) + Base64ImageTagPrefix.size(); + size_t PayloadPos = std::get<1>(ImagePos); + size_t BeginBytePos = std::get<1>(ImagePos) + Base64ImageBytesPrefix.size(); + size_t EndTagPos = std::get<2>(ImagePos); + std::string_view Payload = + std::string_view(Prompt).substr(BeginBytePos, EndTagPos - BeginBytePos); + std::string ImageType = Prompt.substr(TypePos, PayloadPos - TypePos); + + // Decode the base64 payload. + auto RequiredBytes = base64::required_encode_size(Payload.size()); + std::vector ImageBytes(RequiredBytes); + try { + base64::decode(Payload.begin(), Payload.end(), ImageBytes.begin()); + } catch (const base64_error &E) { + RET_ERROR(std::make_pair(std::vector(), ""), + "base64: Error when calling base64::decode: {}"sv, E.what()) + } + + // Replace the base64 image with the placeholder. + Prompt.replace(BeginTagPos, + EndTagPos - BeginTagPos + Base64ImageTagSuffix.size(), + Placeholder); + return std::make_pair(ImageBytes, ImageType); +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/utils.h b/plugins/wasi_nn/GGML/utils.h new file mode 100644 index 00000000..943b22ba --- /dev/null +++ b/plugins/wasi_nn/GGML/utils.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC +#pragma once +#include "GGML/core/ggml_core.h" + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +const std::string_view Base64ImageTagPrefix = ""sv; +const std::string_view VisionPromptImagePlaceholder = ""sv; + +struct llama_batch allocBatch(int64_t NTokens, int64_t Embd = 0, + int32_t NSeqMax = 1) noexcept; +std::optional> +findBase64ImagePayload(std::string_view Prompt, + bool IsDebugLog = false) noexcept; +std::optional, std::string>> +extractBase64ImagePayload(std::string &Prompt, + std::tuple ImagePos, + const std::string_view Placeholder) noexcept; +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/MLX/mlx/activations.cpp b/plugins/wasi_nn/MLX/mlx/activations.cpp new file mode 100644 index 00000000..137a9e35 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/activations.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/activations.h" + +#include + +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core { + +mx::array gelu(mx::array X) { + // auto Result = X * (1 + mx::erf(X / std::sqrt(2))) / 2; + auto Result = X * + (mx::array({1}, X.dtype()) + + mx::erf(X / mx::array({std::sqrt(2)}, X.dtype()))) / + mx::array({2}, X.dtype()); + return Result; +} + +mx::array silu(mx::array X) { return X * mx::sigmoid(X); } + +mx::array geluApprox(mx::array X) { + return mx::array({0.5}, X.dtype()) * X * + (mx::array({1}, X.dtype()) + + mx::tanh(mx::array({std::sqrt(2.0 / M_PI)}, X.dtype()) * + (X + mx::array({0.044715}, X.dtype()) * + mx::power(X, mx::array({3}, X.dtype()))))); +} +} // namespace mlx::core +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/activations.h b/plugins/wasi_nn/MLX/mlx/activations.h new file mode 100644 index 00000000..15d0dfd4 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/activations.h @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" + +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core { + +mx::array gelu(mx::array X); + +mx::array silu(mx::array X); +mx::array geluApprox(mx::array X); +} // namespace mlx::core +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp new file mode 100644 index 00000000..0030bb17 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/base.h" +#include "model/utils.h" + +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +mx::array &Module::registerParameter(std::string ParamName, mx::array &&W) { + Parameters.insert({ParamName, W}); + return Parameters.at(ParamName); +} + +void Module::update(std::unordered_map NewParameters) { + for (auto &[K, V] : NewParameters) { + apply(K, V); + } +} + +std::shared_ptr Module::toQuantized( + int GroupSize, int Bits, const std::string &Prefix, + const std::unordered_map &LoadedWeights) { + auto NewPrefix = Prefix + Name + (Prefix.empty() && Name.empty() ? "" : "."); + for (auto &[K, V] : Submodules) { + if (V->hasQuantize()) { + auto Weights = V->Parameters.find("weight"); + if (Weights != V->Parameters.end() && !LoadedWeights.empty()) { + if (LoadedWeights.count(NewPrefix + V->Name + ".scales") == 0) { + continue; + } + } + if (Weights != V->Parameters.end() && + Weights->second.shape().back() % GroupSize != 0) { + continue; + } + } + V = V->toQuantized(GroupSize, Bits, + Prefix + Name + (Name.empty() ? "" : "."), + LoadedWeights); + } + return shared_from_this(); +} + +void Module::apply(std::string Key, mx::array Value) { + std::vector SplitKey = splitString(Key, '.'); + if (SplitKey.size() == 1) { + if (Parameters.find(Key) == Parameters.end()) { + spdlog::error("[WASI-NN] MLX backend: Unsupported weight: {}"sv, Key); + assumingUnreachable(); + } + this->Parameters.at(Key) = Value; + } else { + std::string LayerName = SplitKey[0]; + SplitKey.erase(SplitKey.begin()); + if (LayerName == "layers" || LayerName == "blocks") { + LayerName += "." + SplitKey[0]; + SplitKey.erase(SplitKey.begin()); + } + if (Submodules.find(LayerName) == Submodules.end()) { + spdlog::error("[WASI-NN] MLX backend: Unsupported Layer: {}"sv, + LayerName); + assumingUnreachable(); + } + Submodules.at(LayerName)->apply(joinString(SplitKey, '.'), Value); + } +} + +std::unordered_map +Module::getWeigts(const std::string &Prefix) { + std::unordered_map Weights; + auto NewPrefix = Prefix + Name; + for (auto &[K, V] : Submodules) { + auto Subweights = V->getWeigts(NewPrefix + (NewPrefix.empty() ? "" : ".")); + Weights.insert(Subweights.begin(), Subweights.end()); + } + for (auto &[K, V] : Parameters) { + Weights.insert({NewPrefix + (NewPrefix.empty() ? "" : ".") + K, V}); + } + return Weights; +} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/base.h b/plugins/wasi_nn/MLX/mlx/base.h new file mode 100644 index 00000000..971b4f32 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/base.h @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "common/errcode.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace std::literals::string_view_literals; + +namespace WasmEdge::Host::WASINN::MLX { + +namespace mx = mlx::core; + +namespace mlx::core::nn { + +class Module : public std::enable_shared_from_this { +public: + std::string Name; + std::unordered_map Parameters; + std::unordered_map> Submodules; + + virtual ~Module() = default; + + mx::array ®isterParameter(std::string ParamName, mx::array &&W); + + std::unordered_map + getWeigts(const std::string &Prefix = "model"); + + virtual std::shared_ptr toQuantized( + int GroupSize = 64, int Bits = 4, const std::string &Prefix = "", + const std::unordered_map &LoadedWeights = {}); + + virtual bool hasQuantize() { return false; } + + void update(std::unordered_map NewParameters); + + void apply(std::string Key, mx::array Value); + + template + void registerModule(std::string ModuleName, std::shared_ptr M) { + using DecayedT = std::decay_t; + if (!std::is_base_of::value) { + spdlog::error("[WASI-NN] MLX backend: Invalid subModule."sv); + assumingUnreachable(); + } + + if (Submodules.find(ModuleName) == Submodules.end()) { + Submodules.insert({ModuleName, M}); + Submodules.at(ModuleName)->Name = ModuleName; + } else { + spdlog::error("[WASI-NN] MLX backend: Module already exists."sv); + assumingUnreachable(); + } + } + + template + void registerLayer(std::string ModuleName, + std::vector> &Layers) { + if (!std::is_base_of::value) { + spdlog::error("[WASI-NN] MLX backend: Invalid subModule."sv); + assumingUnreachable(); + } + for (size_t Idx = 0; Idx < Layers.size(); Idx++) { + registerModule(ModuleName + "." + std::to_string(Idx), Layers[Idx]); + } + } +}; + +} // namespace mlx::core::nn + +template void printVec(std::vector Ve) { + for (auto I : Ve) { + spdlog::debug("[WASI-NN] MLX backend: {} ."sv, I); + } +} + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/convolution.cpp b/plugins/wasi_nn/MLX/mlx/convolution.cpp new file mode 100644 index 00000000..4e38a69b --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/convolution.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "convolution.h" +#include +#include +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +mx::array Conv1d::forward(mx::array Input) { + auto Y = mx::conv1d(Input, Parameters.at("weight"), Stride, Padding, Dilation, + Groups); + if (Parameters.find("bias") != Parameters.end()) { + Y = Y + Parameters.at("bias"); + } + return Y; +} + +mx::array Conv2d::forward(mx::array Input) { + auto Y = mx::conv2d(Input, Parameters.at("weight"), Stride, Padding, Dilation, + Groups); + if (Parameters.find("bias") != Parameters.end()) { + Y = Y + Parameters.at("bias"); + } + return Y; +} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/convolution.h b/plugins/wasi_nn/MLX/mlx/convolution.h new file mode 100644 index 00000000..db475a5f --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/convolution.h @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "base.h" + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +class Conv1d : public nn::Module { + int Padding; + int Stride; + int Dilation; + int Groups; + +public: + Conv1d(int InChannels, int OutChannels, int KernelSize, int Stride = 1, + int Padding = 0, int Dilation = 1, int Groups = 1, bool Bias = true) + : Padding(Padding), Stride(Stride), Dilation(Dilation), Groups(Groups) { + double Scale = std::sqrt(1.0 / (InChannels * KernelSize)); + registerParameter( + "weight", mx::random::uniform(-Scale, Scale, + {OutChannels, InChannels, KernelSize})); + if (Bias) { + registerParameter("bias", mx::zeros({OutChannels})); + } + } + mx::array forward(mx::array Input); +}; + +class Conv2d : public nn::Module { + std::pair Padding; + std::pair Stride; + std::pair Dilation; + int Groups; + +public: + Conv2d(int InChannels, int OutChannels, int KernelSize, + std::pair Stride = {1, 1}, + std::pair Padding = {0, 0}, + std::pair Dilation = {1, 1}, int Groups = 1, + bool Bias = true) + : Padding(Padding), Stride(Stride), Dilation(Dilation), Groups(Groups) { + + if (InChannels % Groups != 0) { + // InChannels must be divisible by Groups. + assumingUnreachable(); + } + double Scale = std::sqrt(1.0 / (InChannels * KernelSize * KernelSize)); + registerParameter("weight", mx::random::uniform(-Scale, Scale, + {OutChannels, InChannels, + KernelSize, KernelSize})); + if (Bias) { + registerParameter("bias", mx::zeros({OutChannels})); + } + } + mx::array forward(mx::array Input); +}; + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/embedding.cpp b/plugins/wasi_nn/MLX/mlx/embedding.cpp new file mode 100644 index 00000000..be9367ef --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/embedding.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/embedding.h" +#include "mlx/quantized.h" + +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +mx::array Embedding::forward(mx::array Input) { + return take(Parameters.at("weight"), Input, 0); +} + +mx::array Embedding::asLinear(mx::array Input) { + return matmul(Input, transpose(Parameters.at("weight"))); +} + +std::shared_ptr +Embedding::toQuantized(int GroupSize, int Bits, const std::string &, + const std::unordered_map &) { + auto QuantModel = QuantizedEmbedding::fromEmbedding( + std::dynamic_pointer_cast(shared_from_this()), GroupSize, + Bits); + QuantModel->Name = Name; + return QuantModel; +} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/embedding.h b/plugins/wasi_nn/MLX/mlx/embedding.h new file mode 100644 index 00000000..66175d0e --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/embedding.h @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" + +#include +#include + +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +class Embedding : public Module { +public: + Embedding() = default; + + Embedding(int NumEmbeddings, int Dims) { + const double Scale = std::sqrt(1.0 / Dims); + registerParameter("weight", + mx::random::normal({NumEmbeddings, Dims}, 0.0, Scale)); + } + + virtual mx::array forward(mx::array Input); + + mx::array asLinear(mx::array Input); + + std::shared_ptr + toQuantized(int GroupSize = 64, int Bits = 4, const std::string &Prefix = "", + const std::unordered_map &LoadedWeights = + {}) override; + + virtual bool hasQuantize() override { return true; } +}; + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/linear.cpp b/plugins/wasi_nn/MLX/mlx/linear.cpp new file mode 100644 index 00000000..1de1ab79 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/linear.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/linear.h" +#include "mlx/quantized.h" + +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +mx::array Linear::forward(mx::array Input) { + if (EnableBias) { + return mx::addmm(Parameters.at("bias"), Input, + transpose(Parameters.at("weight"))); + } + return matmul(Input, transpose(Parameters.at("weight"))); +} + +std::shared_ptr +Linear::toQuantized(int GroupSize, int Bits, const std::string &, + const std::unordered_map &) { + auto QuantModel = QuantizedLinear::fromLinear( + std::dynamic_pointer_cast(shared_from_this()), GroupSize, Bits); + QuantModel->Name = Name; + return QuantModel; +} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/linear.h b/plugins/wasi_nn/MLX/mlx/linear.h new file mode 100644 index 00000000..a8b07360 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/linear.h @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" + +#include +#include +#include + +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +class Linear : public Module { + bool EnableBias = true; + +public: + Linear() = default; + Linear(int InputDims, int OutputDims, bool EnableBias = true) + : EnableBias(EnableBias) { + const double Scale = std::sqrt(1.0 / InputDims); + registerParameter( + "weight", mx::random::uniform(-Scale, Scale, {OutputDims, InputDims})); + if (EnableBias) { + registerParameter("bias", mx::random::uniform(-Scale, Scale, + { + OutputDims, + })); + } + } + + virtual mx::array forward(mx::array Input); + std::shared_ptr + toQuantized(int GroupSize = 64, int Bits = 4, const std::string &Prefix = "", + const std::unordered_map &LoadedWeights = + {}) override; + + virtual bool hasQuantize() override { return true; } +}; + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/normalization.cpp b/plugins/wasi_nn/MLX/mlx/normalization.cpp new file mode 100644 index 00000000..dfc95205 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/normalization.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/normalization.h" + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +mx::array RMSNorm::forward(mx::array Input) { + return mx::fast::rms_norm(Input, Parameters.at("weight"), Eps); +} + +mx::array LayerNorm::forward(mx::array Input) { + std::optional Weight = {}; + std::optional Bias = {}; + if (Parameters.find("weight") != Parameters.end()) { + Weight = Parameters.at("weight"); + } + if (Parameters.find("bias") != Parameters.end()) { + Bias = Parameters.at("bias"); + } + + return mx::fast::layer_norm(Input, Weight, Bias, Eps); +} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/normalization.h b/plugins/wasi_nn/MLX/mlx/normalization.h new file mode 100644 index 00000000..81ee7e3e --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/normalization.h @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +class RMSNorm : public nn::Module { + float Eps; + +public: + RMSNorm(int Dims, float Eps = 1e-5) : Eps(Eps) { + registerParameter("weight", mx::ones({Dims})); + } + mx::array forward(mx::array Input); +}; + +class LayerNorm : public nn::Module { + int Dims; + float Eps; + +public: + LayerNorm(int Dims, float Eps = 1e-5, bool Affine = true, bool Bias = true) + : Dims(Dims), Eps(Eps) { + if (Affine) { + registerParameter("weight", mx::ones({Dims})); + if (Bias) { + registerParameter("bias", mx::zeros({Dims})); + } + } + } + mx::array forward(mx::array Input); +}; + +} // namespace mlx::core::nn + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/pooling.cpp b/plugins/wasi_nn/MLX/mlx/pooling.cpp new file mode 100644 index 00000000..f81349d7 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/pooling.cpp @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "pooling.h" +#include "base.h" +#include "spdlog/spdlog.h" +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +namespace { +std::vector valueOrList(const int Value, int Len) { + return std::vector(Len, Value); +} + +std::vector> +makePaddingPairs(const std::vector &Pads) { + std::vector> Pairs; + for (int P : Pads) { + Pairs.push_back({P, P}); + } + return Pairs; +} + +std::vector> +makePadShape(const std::vector> &Padding) { + std::vector> PadShape; + PadShape.push_back({0, 0}); + for (auto &P : Padding) { + PadShape.push_back(P); + } + PadShape.push_back({0, 0}); + return PadShape; +} + +} // namespace + +mx::array nonOverlappingSlidingWindows(const mx::array &X, + const std::vector &Shape, + const std::vector &WindowShape) { + std::vector NewShape; + NewShape.push_back(Shape[0]); + for (size_t I = 1; I < std::min(Shape.size(), WindowShape.size() + 1); I++) { + int S = Shape[I]; + int W = WindowShape[I - 1]; + NewShape.push_back(S / W); + NewShape.push_back(W); + } + NewShape.push_back(Shape.back()); + int LastAxis = NewShape.size() - 1; + std::vector AxisOrder; + AxisOrder.push_back(0); + for (int I = 1; I < LastAxis; I += 2) { + AxisOrder.push_back(I); + } + for (int I = 2; I < LastAxis; I += 2) { + AxisOrder.push_back(I); + } + AxisOrder.push_back(LastAxis); + return transpose(reshape(X, NewShape), AxisOrder); +} + +mx::array slidingWindows(const mx::array &X, + const std::vector &WindowShape, + const std::vector &WindowStrides) { + if (X.ndim() < 3) { + spdlog::error( + "To extract sliding windows at least 1 spatial dimension (3 total) is " + "needed but the input only has " + + std::to_string(X.ndim()) + " dimensions."); + assumingUnreachable(); + } + std::vector Shape = X.shape(); + size_t SpatialDimsCount = Shape.size() - 2; + if (SpatialDimsCount != WindowShape.size() || + WindowShape.size() != WindowStrides.size()) { + + spdlog::error( + "The window shapes and strides must have the same number of spatial " + "dimensions as the signal."); + assumingUnreachable(); + } + bool UseNonOverlap = true; + for (size_t I = 0; I < SpatialDimsCount; I++) { + if (!(WindowShape[I] == WindowStrides[I] && + (Shape[I + 1] % WindowShape[I] == 0))) { + UseNonOverlap = false; + break; + } + } + if (UseNonOverlap) + return nonOverlappingSlidingWindows(X, Shape, WindowShape); + size_t N = Shape.size(); + std::vector BaseStrides(N); + BaseStrides[N - 1] = 1; + for (int I = N - 2; I >= 0; I--) { + BaseStrides[I] = Shape[I + 1] * BaseStrides[I + 1]; + } + std::vector FinalShape; + FinalShape.push_back(Shape[0]); + for (size_t I = 0; I < SpatialDimsCount; I++) { + int OutDim = (Shape[I + 1] - WindowShape[I]) / WindowStrides[I] + 1; + FinalShape.push_back(OutDim); + } + FinalShape.insert(FinalShape.end(), WindowShape.begin(), WindowShape.end()); + FinalShape.push_back(Shape.back()); + std::vector FinalStrides; + FinalStrides.push_back(BaseStrides[0]); + for (size_t I = 1; I < BaseStrides.size() - 1; I++) { + FinalStrides.push_back(BaseStrides[I] * WindowStrides[I - 1]); + } + for (size_t I = 1; I < BaseStrides.size(); I++) { + FinalStrides.push_back(BaseStrides[I]); + } + return mx::as_strided(X, FinalShape, FinalStrides, 0); +} + +Pool::Pool( + const std::function &)> + &PoolingFunction, + const std::vector &KernelSize, const std::vector &Stride, + const std::vector> &Padding, int PaddingValue) + : PoolingFunction(PoolingFunction), KernelSize(KernelSize), Stride(Stride), + Padding(Padding), PaddingValue(PaddingValue) { + int N = KernelSize.size(); + for (int I = -(N + 1); I < -1; I++) { + Axes.push_back(I); + } +} + +mx::array Pool::forward(const mx::array &X) { + mx::array Out = X; + bool NeedPad = false; + for (auto &P : Padding) { + if (P.first > 0) { + NeedPad = true; + break; + } + } + if (NeedPad) { + std::vector> PadShape = makePadShape(Padding); + Out = mx::pad(Out, PadShape, mx::array(PaddingValue)); + } + Out = slidingWindows(Out, KernelSize, Stride); + return PoolingFunction(Out, Axes); +} + +Pool2d::Pool2d( + const std::function &)> + &PoolingFn, + int PadValue, const std::vector &KernelSizes, + const std::optional> &StrideOpt, + const std::optional> &PaddingOpt) + : Pool(PoolingFn, + KernelSizes.size() == 1 ? valueOrList(KernelSizes[0], 2) + : KernelSizes, + (StrideOpt.has_value() + ? (StrideOpt.value().size() == 1 + ? valueOrList(StrideOpt.value()[0], 2) + : StrideOpt.value()) + : (KernelSizes.size() == 1 ? valueOrList(KernelSizes[0], 2) + : KernelSizes)), + makePaddingPairs(PaddingOpt.has_value() ? PaddingOpt.value() + : valueOrList(0, 2)), + PadValue) {} + +AvgPool2d::AvgPool2d(const std::vector &KernelSizes, + const std::optional> &StrideOpt, + const std::optional> &PaddingOpt) + : Pool2d( + [](const mx::array &A, const std::vector &Axis) -> mx::array { + return mx::mean(A, Axis, false); + }, + 0, KernelSizes, StrideOpt, PaddingOpt) {} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/pooling.h b/plugins/wasi_nn/MLX/mlx/pooling.h new file mode 100644 index 00000000..c5a20fc2 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/pooling.h @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "base.h" +namespace WasmEdge::Host::WASINN::MLX { + +namespace mlx::core::nn { +class Pool : public nn::Module { +public: + Pool(const std::function &)> &PoolingFunction, + const std::vector &KernelSize, const std::vector &Stride, + const std::vector> &Padding, int PaddingValue); + mx::array forward(const mx::array &X); + +protected: + std::function &)> + PoolingFunction; + std::vector KernelSize; + std::vector Stride; + std::vector> Padding; + int PaddingValue; + std::vector Axes; +}; + +class Pool2d : public Pool { +public: + Pool2d(const std::function &)> &PoolingFn, + int PadValue, const std::vector &KernelSizes, + const std::optional> &StrideOpt, + const std::optional> &PaddingOpt); +}; + +class AvgPool2d : public Pool2d { +public: + AvgPool2d(const std::vector &KernelSizes, + const std::optional> &StrideOpt = std::nullopt, + const std::optional> &PaddingOpt = std::nullopt); +}; + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp new file mode 100644 index 00000000..635bd990 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/positional_encoding.h" + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +mx::array RoPE::forward(mx::array Input, int Offset) { + return mx::fast::rope(Input, Dims, Tranditional, Base, Scale, Offset); +} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.h b/plugins/wasi_nn/MLX/mlx/positional_encoding.h new file mode 100644 index 00000000..95d1b676 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.h @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" + +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +class RoPE : public Module { + int Dims; + bool Tranditional; + float Base; + float Scale; + +public: + RoPE(int Dims, bool Traditional = false, float Base = 10000, + float Scale = 1.0) + : Dims(Dims), Tranditional(Traditional), Base(Base), Scale(Scale) {} + + mx::array forward(mx::array Input, int Offset = 0); +}; + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/quantized.cpp b/plugins/wasi_nn/MLX/mlx/quantized.cpp new file mode 100644 index 00000000..02de23a4 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/quantized.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/quantized.h" + +#include +#include + +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +mx::array QuantizedEmbedding::forward(mx::array Input) { + auto S = Input.shape(); + auto X = mx::flatten(Input); + auto Out = + mx::dequantize(take(Parameters.at("weight"), Input, 0), + take(Parameters.at("scales"), Input, 0), + take(Parameters.at("biases"), Input, 0), GroupSize, Bits); + return Out; +} + +mx::array QuantizedLinear::forward(mx::array Input) { + auto Out = mx::quantized_matmul( + Input, Parameters.at("weight"), Parameters.at("scales"), + Parameters.at("biases"), true, GroupSize, Bits); + if (Parameters.find("bias") != Parameters.end()) { + Out = add(Out, Parameters.at("bias")); + } + return Out; +} + +std::shared_ptr +QuantizedEmbedding::fromEmbedding(std::shared_ptr EmbeddingModule, + int GroupSize, int Bits) { + auto EmbeddingShape = EmbeddingModule->Parameters.at("weight").shape(); + auto QuantizedModel = std::make_shared(QuantizedEmbedding( + EmbeddingShape[0], EmbeddingShape[1], GroupSize, Bits)); + auto Quantized = + mx::quantize(EmbeddingModule->Parameters.at("weight"), GroupSize, Bits); + QuantizedModel->Parameters.insert_or_assign("weight", std::get<0>(Quantized)); + QuantizedModel->Parameters.insert_or_assign( + "scales", std::move(std::get<1>(Quantized))); + QuantizedModel->Parameters.insert_or_assign( + "biases", std::move(std::get<2>(Quantized))); + return QuantizedModel; +} + +std::shared_ptr +QuantizedLinear::fromLinear(std::shared_ptr LinearModule, int GroupSize, + int Bits) { + auto LinearShape = LinearModule->Parameters.at("weight").shape(); + auto OutputDims = LinearShape[0]; + auto InputDims = LinearShape[1]; + const bool EnableBias = + LinearModule->Parameters.find("bias") != LinearModule->Parameters.end(); + auto QuantizedModel = std::make_shared( + QuantizedLinear(InputDims, OutputDims, EnableBias, GroupSize, Bits)); + auto Quantized = + mx::quantize(LinearModule->Parameters.at("weight"), GroupSize, Bits); + QuantizedModel->Parameters.insert_or_assign("weight", std::get<0>(Quantized)); + QuantizedModel->Parameters.insert_or_assign( + "scales", std::move(std::get<1>(Quantized))); + QuantizedModel->Parameters.insert_or_assign( + "biases", std::move(std::get<2>(Quantized))); + if (EnableBias) { + QuantizedModel->Parameters.insert_or_assign( + "bias", LinearModule->Parameters.at("bias")); + } + return QuantizedModel; +} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/quantized.h b/plugins/wasi_nn/MLX/mlx/quantized.h new file mode 100644 index 00000000..3bbe67fe --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/quantized.h @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" + +#include +#include + +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +class QuantizedEmbedding : public Embedding { +public: + int GroupSize; + int Bits; + int NumEmbeddings; + int Dims; + + QuantizedEmbedding(int NumEmbeddings, int Dims, int GroupSize = 64, + int Bits = 4) + : GroupSize(GroupSize), Bits(Bits), NumEmbeddings(NumEmbeddings), + Dims(Dims) { + const double Scale = std::sqrt(1.0 / Dims); + registerParameter("weight", + mx::random::normal({NumEmbeddings, Dims}, 0.0, Scale)); + auto Quantized = mx::quantize(Parameters.at("weight"), GroupSize, Bits); + Parameters.insert_or_assign("weight", std::get<0>(Quantized)); + registerParameter("scales", std::move(std::get<1>(Quantized))); + registerParameter("biases", std::move(std::get<2>(Quantized))); + } + + mx::array forward(mx::array Input) override; + + static std::shared_ptr + fromEmbedding(std::shared_ptr EmbeddingModule, int GroupSize = 64, + int Bits = 4); +}; + +class QuantizedLinear : public Linear { +public: + int GroupSize; + int Bits; + + QuantizedLinear(int InputDims, int OutputDim, bool Bias = true, + int GroupSize = 64, int Bits = 4) + : GroupSize(GroupSize), Bits(Bits) { + const double Scale = std::sqrt(1.0 / InputDims); + registerParameter( + "weight", mx::random::uniform(-Scale, Scale, {OutputDim, InputDims})); + auto Quantized = mx::quantize(Parameters.at("weight"), GroupSize, Bits); + Parameters.insert_or_assign("weight", std::get<0>(Quantized)); + registerParameter("scales", std::move(std::get<1>(Quantized))); + registerParameter("biases", std::move(std::get<2>(Quantized))); + if (Bias) { + registerParameter("bias", mx::zeros({OutputDim})); + } + } + + mx::array forward(mx::array Input) override; + + static std::shared_ptr + fromLinear(std::shared_ptr LinearModule, int GroupSize = 64, + int Bits = 4); +}; + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/transformer.cpp b/plugins/wasi_nn/MLX/mlx/transformer.cpp new file mode 100644 index 00000000..76f4d2b6 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/transformer.cpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/transformer.h" + +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +mx::array MultiHeadAttention::createAdditiveCausalMask(int N, mx::Dtype DType) { + auto Indices = mx::arange(N); + // mask = indices[:, None] < indices[None] + auto Mask = reshape(Indices, {N, 1}) < reshape(Indices, {1, N}); + // using 1e9 instead of INF, and softmax(full(1e9)) != nan + Mask = astype(Mask, DType) * -1e9; + return Mask; +} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/transformer.h b/plugins/wasi_nn/MLX/mlx/transformer.h new file mode 100644 index 00000000..fc2c8410 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/transformer.h @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" +#include "mlx/linear.h" + +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +class MultiHeadAttention : public Module { + int NumHeads; + +public: + MultiHeadAttention(int Dims, int NumHeads, + std::optional QueryInputDims = {}, + std::optional KeyInputDims = {}, + std::optional ValueInputDims = {}, + std::optional ValueDims = {}, + std::optional ValueOutputDims = {}, bool Bias = false) + : NumHeads(NumHeads) { + if (Dims % NumHeads != 0) { + spdlog::error( + "[WASI-NN] MLX backend: Dims must be divisible by NumHeads"sv); + assumingUnreachable(); + } + if (!QueryInputDims) { + QueryInputDims = Dims; + } + if (!KeyInputDims) { + KeyInputDims = Dims; + } + if (!ValueInputDims) { + ValueInputDims = KeyInputDims; + } + if (!ValueDims) { + ValueDims = Dims; + } + if (!ValueOutputDims) { + ValueOutputDims = Dims; + } + registerModule("query_proj", std::make_shared( + Linear(*QueryInputDims, Dims, Bias))); + registerModule("key_proj", + std::make_shared(Linear(*KeyInputDims, Dims, Bias))); + registerModule("value_proj", std::make_shared(Linear( + *ValueInputDims, *ValueDims, Bias))); + registerModule("out_proj", std::make_shared( + Linear(*ValueDims, *ValueOutputDims, Bias))); + }; + + mx::array forward(mx::array Queries, mx::array Keys, mx::array Values, + mx::array Mask); + + static mx::array createAdditiveCausalMask(int N, + mx::Dtype DType = mx::float32); +}; + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/converter.cpp b/plugins/wasi_nn/MLX/model/converter.cpp new file mode 100644 index 00000000..4d20b36c --- /dev/null +++ b/plugins/wasi_nn/MLX/model/converter.cpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "model/converter.h" +#include "model/utils.h" + +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +std::unordered_map +weightsToMlx(std::string WeightPath) { + const std::filesystem::path Path(WeightPath); + if (std::filesystem::is_directory(Path)) { + std::unordered_map Loaded; + for (const auto &Entry : std::filesystem::directory_iterator(Path)) { + if (Entry.path().extension() == ".safetensors") { + auto SubWeight = weightsToMlx(Entry.path()); + Loaded.insert(SubWeight.begin(), SubWeight.end()); + } + } + return Loaded; + } + if (endsWith(WeightPath, ".safetensors")) { + spdlog::info( + "[WASI-NN] MLX backend: Loading model from .safetensors file...\n"sv); + const mx::SafetensorsLoad Loaded = mx::load_safetensors(WeightPath); + return Loaded.first; + } + if (endsWith(WeightPath, ".gguf")) { + spdlog::info("[WASI-NN] MLX backend: Loading model from .gguf file...\n"sv); + const mx::GGUFLoad Loaded = mx::load_gguf(WeightPath); + return Loaded.first; + } + spdlog::error("[WASI-NN] MLX backend: Can not regonize model file\n"sv); + assumingUnreachable(); +} + +std::unordered_map +llamaToMlxllm(std::string WeightPath) { + std::unordered_map ModelWeights; + auto Weight = weightsToMlx(WeightPath); + for (auto &[K, V] : Weight) { + std::string NewKey = K; + if (startsWith(NewKey, "model.")) { + strReplace(NewKey, "model.", ""); + } + std::vector SplitKey = splitString(NewKey, '.'); + if (find(SplitKey.begin(), SplitKey.end(), "layers") != SplitKey.end()) { + if (find(SplitKey.begin(), SplitKey.end(), "rotary_emb") != + SplitKey.end()) { + continue; + } + if (find(SplitKey.begin(), SplitKey.end(), "self_attn") != + SplitKey.end()) { + ModelWeights.insert({SplitKey[0] + "." + SplitKey[1] + ".attention." + + SplitKey[3] + "." + SplitKey[4], + V}); + } else if (find(SplitKey.begin(), SplitKey.end(), "mlp") != + SplitKey.end()) { + + ModelWeights.insert({NewKey, V}); + } else { + const std::unordered_map KeyMap = { + {"input_layernorm", "attention_norm"}, + {"post_attention_layernorm", "mlp_norm"}}; + if (KeyMap.find(SplitKey[2]) == KeyMap.end()) { + ModelWeights.insert({NewKey, V}); + } else { + ModelWeights.insert({SplitKey[0] + "." + SplitKey[1] + "." + + KeyMap.at(SplitKey[2]) + "." + SplitKey[3], + V}); + } + } + } else { + const std::unordered_map KeyMap = { + {"embed_tokens", "token_embed"}, + {"lm_head", "head"}, + {"norm", "norm"}}; + if (KeyMap.find(SplitKey[0]) == KeyMap.end()) { + ModelWeights.insert({NewKey, V}); + } else { + ModelWeights.insert({KeyMap.at(SplitKey[0]) + "." + SplitKey[1], V}); + } + } + } + return ModelWeights; +} + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/converter.h b/plugins/wasi_nn/MLX/model/converter.h new file mode 100644 index 00000000..bcf9a41b --- /dev/null +++ b/plugins/wasi_nn/MLX/model/converter.h @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" + +#include + +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +#define strReplace(Str, From, To) Str.replace(Str.find(From), strlen(From), To) + +std::unordered_map weightsToMlx(std::string WeightPath); + +std::unordered_map +llamaToMlxllm(std::string WeightPath); + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp b/plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp new file mode 100644 index 00000000..e54efa39 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "model/gemma3/gemma3.h" +#include "language.h" +#include "mlx/embedding.h" +#include "vision.h" +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace gemma3 { + +ModelConfig ModelConfig::fromDict(const simdjson::dom::object &Obj) { + ModelConfig Config; + auto ModelTypeResult = Obj["model_type"].get_string(); + if (!ModelTypeResult.error()) { + Config.ModelType = std::string(ModelTypeResult.value()); + } + auto VocabResult = Obj["vocab_size"].get_int64(); + if (!VocabResult.error()) { + Config.VocabSize = static_cast(VocabResult.value()); + } + auto IgnoreResult = Obj["ignore_index"].get_int64(); + if (!IgnoreResult.error()) { + Config.IgnoreIndex = static_cast(IgnoreResult.value()); + } + auto ImageTokenResult = Obj["image_token_index"].get_int64(); + if (!ImageTokenResult.error()) { + Config.ImageTokenIndex = static_cast(ImageTokenResult.value()); + } + auto HiddenSizeResult = Obj["hidden_size"].get_int64(); + if (!HiddenSizeResult.error()) { + Config.HiddenSize = static_cast(HiddenSizeResult.value()); + } + auto PadTokenResult = Obj["pad_token_id"].get_int64(); + if (!PadTokenResult.error()) { + Config.PadTokenId = static_cast(PadTokenResult.value()); + } + return Config; +} + +Gemma3MultiModalProjector::Gemma3MultiModalProjector( + const ModelConfig &Config) { + registerModule("mm_soft_emb_norm", + std::make_shared(Config.VisionConfig.HiddenSize, + Config.VisionConfig.LayerNormEps)); + registerParameter( + "mm_input_projection_weight", + mx::ones({Config.VisionConfig.HiddenSize, Config.TextConfig.HiddenSize})); + PatchesPerImage = + Config.VisionConfig.ImageSize / Config.VisionConfig.PatchSize; + TokensPerSide = static_cast( + std::sqrt(static_cast(Config.TextConfig.MmTokensPerImage))); + KernelSize = PatchesPerImage / TokensPerSide; + AvgPool = + nn::AvgPool2d(std::vector{KernelSize}, std::vector{KernelSize}); +} + +mx::array Gemma3MultiModalProjector::forward(const mx::array &X) { + int B = X.shape()[0]; + int L = X.shape()[2]; + mx::array ReshapedVisionOutputs = transpose(X, {0, 2, 1}); + ReshapedVisionOutputs = + reshape(ReshapedVisionOutputs, {B, L, PatchesPerImage, PatchesPerImage}); + ReshapedVisionOutputs = transpose(ReshapedVisionOutputs, {0, 2, 3, 1}); + mx::array PooledVisionOutputs = AvgPool.forward(ReshapedVisionOutputs); + PooledVisionOutputs = transpose(PooledVisionOutputs, {0, 3, 1, 2}); + PooledVisionOutputs = flatten(PooledVisionOutputs, 2); + PooledVisionOutputs = transpose(PooledVisionOutputs, {0, 2, 1}); + mx::array NormedVisionOutputs = + std::dynamic_pointer_cast(Submodules["mm_soft_emb_norm"]) + ->forward(PooledVisionOutputs); + mx::array ProjectedVisionOutputs = mx::einsum( + "btm,md->btd", + std::vector( + {NormedVisionOutputs, Parameters.at("mm_input_projection_weight")})); + return astype(ProjectedVisionOutputs, X.dtype()); +} + +Model::Model(const ModelConfig &Config) : Config(Config) { + registerModule("vision_tower", + std::make_shared(Config.VisionConfig)); + registerModule("language_model", + std::make_shared(Config.TextConfig)); + registerModule("multi_modal_projector", + std::make_shared(Config)); + ModelType = Config.ModelType; +} + +std::pair +Model::getInputEmbeddings(const mx::array &InputIds, + const mx::array &PixelValues, const mx::array &Mask) { + if (PixelValues.size() == 0) { + mx::array Embeds = std::dynamic_pointer_cast( + Submodules["language_model"] + ->Submodules["model"] + ->Submodules["embed_tokens"]) + ->forward(InputIds); + return {Embeds, mx::array({})}; + } + mx::array InputsEmbeds = + std::dynamic_pointer_cast(Submodules["language_model"] + ->Submodules["model"] + ->Submodules["embed_tokens"]) + ->forward(InputIds); + mx::array HiddenState = mx::array({}), Temp1 = mx::array({}), + Temp2 = mx::array({}); + std::tie(HiddenState, Temp1, Temp2) = + std::dynamic_pointer_cast(Submodules["vision_tower"]) + ->forward(astype(transpose(PixelValues, {0, 2, 3, 1}), + InputsEmbeds.dtype()), + true); + auto NewShape = HiddenState.shape(); + NewShape.insert(NewShape.begin(), 1); + mx::array ImageFeatures = + astype(reshape(HiddenState, NewShape), PixelValues.dtype()); + ImageFeatures = std::dynamic_pointer_cast( + Submodules["multi_modal_projector"]) + ->forward(ImageFeatures); + return _prepareInputsForMultimodal(ImageFeatures, InputsEmbeds, InputIds, + Mask); +} + +std::pair Model::_prepareInputsForMultimodal( + const mx::array &ImageFeatures, const mx::array &InputsEmbeds, + const mx::array &InputIds, const mx::array &AttentionMask) { + int EmbedDim = ImageFeatures.shape().back(); + int BatchSize = InputIds.shape()[0]; + int SequenceLength = InputIds.shape()[1]; + mx::array ScaledImageFeatures = + ImageFeatures / std::pow(Config.HiddenSize, 0.5); + mx::array FinalEmbedding = mx::zeros({BatchSize, SequenceLength, EmbedDim}); + int PadTokenId = Config.PadTokenId; + mx::array TextMask = + (InputIds != Config.ImageTokenIndex) & (InputIds != PadTokenId); + mx::array ImageMask = (InputIds == Config.ImageTokenIndex); + mx::array PadMask = (InputIds == PadTokenId); + mx::array TextMaskExpanded = expand_dims(TextMask, -1); + TextMaskExpanded = repeat(TextMaskExpanded, EmbedDim, -1); + mx::array PadMaskExpanded = expand_dims(PadMask, -1); + PadMaskExpanded = repeat(PadMaskExpanded, EmbedDim, -1); + FinalEmbedding = mx::where(TextMaskExpanded, InputsEmbeds, FinalEmbedding); + FinalEmbedding = mx::where(PadMaskExpanded, mx::zeros_like(FinalEmbedding), + FinalEmbedding); + int PadSize = FinalEmbedding.shape()[1] - ScaledImageFeatures.shape()[1]; + ScaledImageFeatures = + mx::pad(ScaledImageFeatures, {{0, 0}, {0, PadSize}, {0, 0}}); + mx::array ImageMaskExpanded = expand_dims(ImageMask, -1); + ImageMaskExpanded = repeat(ImageMaskExpanded, EmbedDim, -1); + FinalEmbedding = + mx::where(ImageMaskExpanded, ScaledImageFeatures, FinalEmbedding); + FinalEmbedding = mx::where(PadMaskExpanded, mx::zeros_like(FinalEmbedding), + FinalEmbedding); + mx::array AttentionMaskExpanded1 = expand_dims(AttentionMask, 1); + mx::array AttentionMaskExpanded2 = expand_dims(AttentionMask, 2); + mx::array FinalAttentionMask4d = + AttentionMaskExpanded1 * AttentionMaskExpanded2; + FinalAttentionMask4d = expand_dims(FinalAttentionMask4d, 1); + return {FinalEmbedding, FinalAttentionMask4d}; +} + +std::tuple> Model::forward( + const mx::array &InputIds, const mx::array &PixelValues, + const mx::array &Mask, + const std::optional>> &Cache) { + auto Pair = getInputEmbeddings(InputIds, PixelValues, Mask); + mx::array InputEmbeds = Pair.first; + + // TODO: Waiting for upstream fix Mask + auto Logits = std::dynamic_pointer_cast( + Submodules["language_model"]) + ->forward(InputIds, InputEmbeds, std::nullopt, Cache); + return Logits; +} + +std::shared_ptr Model::fromPretrained(const std::string &ModelPath) { + std::filesystem::path Path(ModelPath); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto Error = Parser.load((Path / "config.json").string()).get(Doc); + if (Error) { + spdlog::error("Could not open config.json"); + assumingUnreachable(); + } + auto Obj = Doc.get_object(); + ModelConfig ModelConfigObj = ModelConfig::fromDict(Obj.value()); + ModelConfigObj.VisionConfig = + VisionConfig::fromDict(Obj["vision_config"].get_object().value()); + ModelConfigObj.TextConfig = + TextConfig::fromDict(Obj["text_config"].get_object().value()); + auto Model = std::make_shared(gemma3::Model(ModelConfigObj)); + std::vector WeightFiles; + for (auto &P : std::filesystem::directory_iterator(Path)) { + if (P.path().extension() == ".safetensors") + WeightFiles.push_back(P.path()); + } + if (WeightFiles.empty()) { + spdlog::error("[WASI-NN] MLX backend: No safetensors found in {}."sv, + Path.string()); + assumingUnreachable(); + } + std::unordered_map Weights; + for (auto &Wf : WeightFiles) { + auto W = mx::load_safetensors(Wf.string()); + Weights.insert(W.first.begin(), W.first.end()); + } + Weights = Model->sanitize(Weights); + Weights = gemma3::VisionModel(ModelConfigObj.VisionConfig).sanitize(Weights); + auto QuantResult = Obj["quantization"].get_object(); + if (!QuantResult.error()) { + auto GroupSize = + static_cast(QuantResult.value()["group_size"].get_int64()); + auto Bits = static_cast(QuantResult.value()["bits"].get_int64()); + spdlog::info( + "[WASI-NN] MLX backend: Quantizing model to {} bits, {} group size."sv, + Bits, GroupSize); + Model = std::dynamic_pointer_cast( + Model->toQuantized(GroupSize, Bits, "", Weights)); + } + Model->update(Weights); + return Model; +} + +std::unordered_map +Model::sanitize(const std::unordered_map &Weights) { + std::unordered_map Sanitized; + for (auto &Pair : Weights) { + std::string Key = Pair.first; + if (Key.find("vision_tower") == std::string::npos) { + size_t Pos = Key.find("vision_model"); + if (Pos != std::string::npos) + Key.replace(Pos, std::string("vision_model").length(), "vision_tower"); + } + if (Key.find("model") == 0) { + Key.replace(0, std::string("model").length(), ""); + } + Sanitized.insert({Key, Pair.second}); + } + return Sanitized; +} + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/gemma3.h b/plugins/wasi_nn/MLX/model/gemma3/gemma3.h new file mode 100644 index 00000000..77ba7346 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/gemma3.h @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" +#include "mlx/pooling.h" +#include "model/gemma3/language.h" +#include "model/gemma3/vision.h" +#include +#include +#include +#include +#include +namespace WasmEdge::Host::WASINN::MLX { + +namespace nn = mlx::core::nn; +namespace gemma3 { + +struct ModelConfig { + TextConfig TextConfig; + VisionConfig VisionConfig; + std::string ModelType = "gemma3"; + int VocabSize = 262208; + int IgnoreIndex = -100; + int ImageTokenIndex = 262144; + int HiddenSize = 2048; + int PadTokenId = 0; + static ModelConfig fromDict(const simdjson::dom::object &Obj); +}; + +class Gemma3MultiModalProjector : public nn::Module { +public: + Gemma3MultiModalProjector(const ModelConfig &Config); + mx::array forward(const mx::array &X); + +private: + nn::AvgPool2d AvgPool = nn::AvgPool2d({0}); + int PatchesPerImage; + int TokensPerSide; + int KernelSize; +}; + +class Model : public vlm::Module { +public: + Model(const ModelConfig &Config); + std::pair + getInputEmbeddings(const mx::array &InputIds, const mx::array &PixelValues, + const mx::array &Mask); + std::pair _prepareInputsForMultimodal( + const mx::array &ImageFeatures, const mx::array &InputsEmbeds, + const mx::array &InputIds, const mx::array &AttentionMask); + std::tuple> forward( + const mx::array &InputIds, const mx::array &PixelValues, + const mx::array &Mask, + const std::optional>> &Cache = + std::nullopt) override; + static std::shared_ptr fromPretrained(const std::string &ModelPath); + std::unordered_map + sanitize(const std::unordered_map &Weights); + ModelConfig Config; + std::string ModelType; +}; + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/language.cpp b/plugins/wasi_nn/MLX/model/gemma3/language.cpp new file mode 100644 index 00000000..138b783a --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/language.cpp @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "model/gemma3/language.h" +#include "mlx/activations.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "mlx/normalization.h" +#include "mlx/positional_encoding.h" +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace gemma3 { + +TextConfig TextConfig::fromDict(const simdjson::dom::object &Obj) { + TextConfig Config; + auto SResult = Obj["model_type"].get_string(); + if (!SResult.error()) + Config.ModelType = std::string(SResult.value()); + auto IResult = Obj["hidden_size"].get_int64(); + if (!IResult.error()) + Config.HiddenSize = static_cast(IResult.value()); + IResult = Obj["num_hidden_layers"].get_int64(); + if (!IResult.error()) + Config.NumHiddenLayers = static_cast(IResult.value()); + IResult = Obj["intermediate_size"].get_int64(); + if (!IResult.error()) + Config.IntermediateSize = static_cast(IResult.value()); + IResult = Obj["num_attention_heads"].get_int64(); + if (!IResult.error()) + Config.NumAttentionHeads = static_cast(IResult.value()); + IResult = Obj["head_dim"].get_int64(); + if (!IResult.error()) + Config.HeadDim = static_cast(IResult.value()); + auto DResult = Obj["rms_norm_eps"].get_double(); + if (!DResult.error()) + Config.RmsNormEps = static_cast(DResult.value()); + IResult = Obj["vocab_size"].get_int64(); + if (!IResult.error()) + Config.VocabSize = static_cast(IResult.value()); + IResult = Obj["num_key_value_heads"].get_int64(); + if (!IResult.error()) + Config.NumKeyValueHeads = static_cast(IResult.value()); + DResult = Obj["rope_global_base_freq"].get_double(); + if (!DResult.error()) + Config.RopeGlobalBaseFreq = static_cast(DResult.value()); + DResult = Obj["rope_local_base_freq"].get_double(); + if (!DResult.error()) + Config.RopeLocalBaseFreq = static_cast(DResult.value()); + auto BResult = Obj["rope_traditional"].get_bool(); + if (!BResult.error()) + Config.RopeTraditional = BResult.value(); + DResult = Obj["query_pre_attn_scalar"].get_double(); + if (!DResult.error()) + Config.QueryPreAttnScalar = static_cast(DResult.value()); + IResult = Obj["sliding_window"].get_int64(); + if (!IResult.error()) + Config.SlidingWindow = static_cast(IResult.value()); + IResult = Obj["mm_tokens_per_image"].get_int64(); + if (!IResult.error()) + Config.MmTokensPerImage = static_cast(IResult.value()); + IResult = Obj["sliding_window_pattern"].get_int64(); + if (!IResult.error()) + Config.SlidingWindowPattern = static_cast(IResult.value()); + return Config; +} + +RMSNorm::RMSNorm(int Dims, float Eps) : Eps(Eps) { + registerParameter("weight", mx::ones({Dims})); +} + +mx::array RMSNorm::forward(const mx::array &X) { + return mx::fast::rms_norm(X, + mx::array({1.0}, Parameters.at("weight").dtype()) + + Parameters.at("weight"), + Eps); +} + +Attention::Attention(const TextConfig &Config, int LayerIdx) + : NHeads(Config.NumAttentionHeads), NKVHeads(Config.NumKeyValueHeads), + Repeats(Config.NumAttentionHeads / Config.NumKeyValueHeads), + HeadDim(Config.HeadDim), LayerIdx(LayerIdx), + Scale(std::pow(Config.QueryPreAttnScalar, -0.5)), + QNorm(HeadDim, Config.RmsNormEps), KNorm(HeadDim, Config.RmsNormEps) { + registerModule("q_proj", std::make_shared( + Config.HiddenSize, NHeads * HeadDim, false)); + registerModule("k_proj", std::make_shared( + Config.HiddenSize, NKVHeads * HeadDim, false)); + registerModule("v_proj", std::make_shared( + Config.HiddenSize, NKVHeads * HeadDim, false)); + registerModule("o_proj", std::make_shared( + NHeads * HeadDim, Config.HiddenSize, false)); + registerModule("q_norm", + std::make_shared(HeadDim, Config.RmsNormEps)); + registerModule("k_norm", + std::make_shared(HeadDim, Config.RmsNormEps)); + IsSliding = ((LayerIdx + 1) % Config.SlidingWindowPattern) != 0; + registerModule("rope", std::make_shared( + HeadDim, Config.RopeTraditional, + (IsSliding ? Config.RopeLocalBaseFreq + : Config.RopeGlobalBaseFreq))); +} + +mx::array Attention::forward( + const mx::array &X, const std::optional &Mask, + const std::optional> &Cache) { + auto Shape = X.shape(); + int B = Shape[0], L = Shape[1]; + mx::array Queries = + std::dynamic_pointer_cast(Submodules["q_proj"])->forward(X); + mx::array Keys = + std::dynamic_pointer_cast(Submodules["k_proj"])->forward(X); + mx::array Values = + std::dynamic_pointer_cast(Submodules["v_proj"])->forward(X); + Queries = transpose(reshape(Queries, {B, L, NHeads, -1}), {0, 2, 1, 3}); + Keys = transpose(reshape(Keys, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + Values = transpose(reshape(Values, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + Queries = std::dynamic_pointer_cast(Submodules["q_norm"]) + ->forward(Queries); + Keys = std::dynamic_pointer_cast(Submodules["k_norm"]) + ->forward(Keys); + if (Cache.has_value()) { + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries, Cache.value()->Offset); + Keys = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Keys, Cache.value()->Offset); + std::tie(Keys, Values) = Cache.value()->updateAndFetch(Keys, Values); + } else { + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries); + Keys = + std::dynamic_pointer_cast(Submodules["rope"])->forward(Keys); + } + if (Mask.has_value() && + Mask.value().shape().back() != Keys.shape().at(Keys.shape().size() - 2)) { + mx::array M = + take(Mask.value(), -Keys.shape().at(Keys.shape().size() - 2), -1); + return std::dynamic_pointer_cast(Submodules["o_proj"]) + ->forward(transpose(reshape(mx::fast::scaled_dot_product_attention( + Queries, Keys, Values, Scale, M), + {B, L, -1}), + {0, 2, 1})); + } + auto Output = (Mask.has_value() + ? mx::fast::scaled_dot_product_attention( + Queries, Keys, Values, Scale, Mask.value()) + : mx::fast::scaled_dot_product_attention(Queries, Keys, + Values, Scale)); + Output = reshape(transpose(Output, {0, 2, 1, 3}), {B, L, -1}); + return std::dynamic_pointer_cast(Submodules["o_proj"]) + ->forward(Output); +} + +MLP::MLP(int Dim, int HiddenDim) { + registerModule("gate_proj", + std::make_shared(Dim, HiddenDim, false)); + registerModule("down_proj", + std::make_shared(HiddenDim, Dim, false)); + registerModule("up_proj", + std::make_shared(Dim, HiddenDim, false)); +} + +mx::array MLP::forward(const mx::array &X) { + mx::array A = std::dynamic_pointer_cast(Submodules["gate_proj"]) + ->forward(X); + A = mlx::core::geluApprox(A); + mx::array B = + std::dynamic_pointer_cast(Submodules["up_proj"])->forward(X); + A = A * B; + return std::dynamic_pointer_cast(Submodules["down_proj"]) + ->forward(A); +} + +TransformerBlock::TransformerBlock(const TextConfig &Config, int LayerIdx) + : NumAttentionHeads(Config.NumAttentionHeads), + HiddenSize(Config.HiddenSize) { + registerModule("self_attn", std::make_shared(Config, LayerIdx)); + registerModule( + "mlp", std::make_shared(Config.HiddenSize, Config.IntermediateSize)); + registerModule("input_layernorm", std::make_shared( + Config.HiddenSize, Config.RmsNormEps)); + registerModule( + "post_attention_layernorm", + std::make_shared(Config.HiddenSize, Config.RmsNormEps)); + registerModule( + "pre_feedforward_layernorm", + std::make_shared(Config.HiddenSize, Config.RmsNormEps)); + registerModule( + "post_feedforward_layernorm", + std::make_shared(Config.HiddenSize, Config.RmsNormEps)); +} + +mx::array TransformerBlock::forward( + const mx::array &Input, const std::optional &Mask, + const std::optional> &Cache) { + auto X = Input; + // Clip the input to avoid overflow in float16, but it make more memory usage + // if (X.dtype() == mx::bfloat16) { + // X = mx::clip(X, mx::array{-65504}, mx::array{65504}); + // } + mx::array R = std::dynamic_pointer_cast(Submodules["self_attn"]) + ->forward(std::dynamic_pointer_cast( + Submodules["input_layernorm"]) + ->forward(X), + Mask, Cache); + mx::array H = std::dynamic_pointer_cast( + Submodules["post_attention_layernorm"]) + ->forward(R); + // if (H.dtype() == mx::bfloat16) { + // H = mx::clip(astype(X, mx::float32) + astype(H, mx::float32), + // mx::array{-65504}, mx::array{65504}); + // } else { + H = X + H; + // } + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward(std::dynamic_pointer_cast( + Submodules["pre_feedforward_layernorm"]) + ->forward(H)); + auto Out = std::dynamic_pointer_cast( + Submodules["post_feedforward_layernorm"]) + ->forward(R); + // if (Out.dtype() == mx::bfloat16) { + // Out = clip(astype(H, mx::float32) + astype(Out, mx::float32), + // mx::array{-65504}, mx::array{65504}); + // } else { + Out = H + Out; + // } + return Out; +} + +Gemma3Model::Gemma3Model(const TextConfig &Config) : Config(Config) { + if (Config.VocabSize <= 0) { + assumingUnreachable(); + } + registerModule("embed_tokens", std::make_shared( + Config.VocabSize, Config.HiddenSize)); + for (int I = 0; I < Config.NumHiddenLayers; I++) { + Layers.push_back(std::make_shared(Config, I)); + } + registerLayer("layers", Layers); + registerModule("norm", std::make_shared(Config.HiddenSize, + Config.RmsNormEps)); +} + +mx::array Gemma3Model::forward( + const mx::array &Inputs, const std::optional &InputsEmbeds, + const std::optional &Mask, + const std::optional>> &Cache) { + mx::array H = + InputsEmbeds.has_value() + ? InputsEmbeds.value() + : std::dynamic_pointer_cast(Submodules["embed_tokens"]) + ->forward(Inputs); + H = H * astype(mx::array(std::pow(Config.HiddenSize, 0.5), mx::bfloat16), + H.dtype()); + std::vector> CacheValue = + Cache.has_value() ? Cache.value() + : std::vector>( + Config.NumHiddenLayers, nullptr); + std::optional FullMask = std::nullopt; + std::optional SlidingWindowMask = std::nullopt; + if (!Mask.has_value()) { + int J = Config.SlidingWindowPattern; + FullMask = vlm::createAttentionMask( + H, std::vector>( + CacheValue.begin() + J - 1, CacheValue.begin() + J)); + SlidingWindowMask = vlm::createAttentionMask(H, CacheValue); + } + for (size_t I = 0; I < Layers.size(); I++) { + bool IsGlobal = + (I % Config.SlidingWindowPattern == Config.SlidingWindowPattern - 1); + std::optional MaskLocal = std::nullopt; + if (!Mask.has_value() && IsGlobal) { + MaskLocal = FullMask; + } else if (!Mask.has_value()) { + MaskLocal = SlidingWindowMask; + } else { + MaskLocal = Mask.value(); + } + H = dynamic_cast(Layers[I].get()) + ->forward(H, MaskLocal, CacheValue[I]); + } + return std::dynamic_pointer_cast(Submodules["norm"]) + ->forward(H); +} + +LanguageModel::LanguageModel(const TextConfig &Config) : Config(Config) { + registerModule("model", std::make_shared(Config)); + registerModule("lm_head", std::make_shared( + Config.HiddenSize, Config.VocabSize, false)); +} + +std::tuple> LanguageModel::forward( + const mx::array &Inputs, const std::optional &InputsEmbeds, + const std::optional &Mask, + const std::optional>> &Cache) { + mx::array Out = std::dynamic_pointer_cast(Submodules["model"]) + ->forward(Inputs, InputsEmbeds, Mask, Cache); + Out = std::dynamic_pointer_cast(Submodules["lm_head"]) + ->forward(Out); + return std::tuple>{Out, {}}; +} + +std::unordered_map LanguageModel::sanitize( + const std::unordered_map &Weights) { + std::unordered_map Sanitized; + if (Weights.find("lm_head.weight") == Weights.end()) + Sanitized.insert({"language_model.lm_head.weight", + Weights.at("language_model.model.embed_tokens.weight")}); + for (auto &Pair : Weights) { + if (Pair.first.find("self_attn.rotary_emb.inv_freq") == std::string::npos) + Sanitized.insert({Pair.first, Pair.second}); + } + return Sanitized; +} + +int LanguageModel::headDim() const { return Config.HeadDim; } + +int LanguageModel::nKvHeads() const { return Config.NumKeyValueHeads; } +int LanguageModel::layers() const { return Config.NumHiddenLayers; } + +std::vector> LanguageModel::makeCache() { + std::vector> Caches; + for (int I = 0; I < Config.NumHiddenLayers; I++) { + if (I % Config.SlidingWindowPattern == Config.SlidingWindowPattern - 1) { + Caches.emplace_back(std::make_shared( + Config.HeadDim, Config.NumKeyValueHeads)); + } else { + Caches.emplace_back( + std::make_shared(Config.SlidingWindow, 0)); + } + } + return Caches; +} + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/language.h b/plugins/wasi_nn/MLX/model/gemma3/language.h new file mode 100644 index 00000000..483a17d4 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/language.h @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "../vlm_base.h" +#include "simdjson.h" +#include +#include +#include +#include +#include +namespace WasmEdge::Host::WASINN::MLX { +namespace nn = mlx::core::nn; +namespace gemma3 { + +struct TextConfig { + std::string ModelType = "gemma3_text"; + int HiddenSize = 2560; + int NumHiddenLayers = 34; + int IntermediateSize = 10240; + int NumAttentionHeads = 8; + int HeadDim = 256; + float RmsNormEps = 1.0e-6; + int VocabSize = 262208; + int NumKeyValueHeads = 4; + float RopeGlobalBaseFreq = 1000000.0f; + float RopeLocalBaseFreq = 10000.0f; + bool RopeTraditional = false; + float QueryPreAttnScalar = 256; + int SlidingWindow = 1024; + std::optional< + std::unordered_map>>> + RopeScaling; + int MmTokensPerImage = 256; + int SlidingWindowPattern = 6; + static TextConfig fromDict(const simdjson::dom::object &Obj); +}; + +class RMSNorm : public nn::Module { +public: + RMSNorm(int Dims, float Eps = 1e-5); + mx::array forward(const mx::array &X); + +private: + float Eps; +}; + +class Attention : public nn::Module { +public: + Attention(const TextConfig &Config, int LayerIdx); + mx::array forward(const mx::array &X, + const std::optional &Mask = std::nullopt, + const std::optional> + &Cache = std::nullopt); + +private: + int NHeads; + int NKVHeads; + int Repeats; + int HeadDim; + int LayerIdx; + float Scale; + RMSNorm QNorm; + RMSNorm KNorm; + bool IsSliding; +}; + +class MLP : public nn::Module { +public: + MLP(int Dim, int HiddenDim); + mx::array forward(const mx::array &X); +}; + +class TransformerBlock : public nn::Module { +public: + TransformerBlock(const TextConfig &Config, int LayerIdx); + mx::array forward(const mx::array &X, + const std::optional &Mask = std::nullopt, + const std::optional> + &Cache = std::nullopt); + +private: + int NumAttentionHeads; + int HiddenSize; +}; + +class Gemma3Model : public nn::Module { +public: + Gemma3Model(const TextConfig &Config); + mx::array forward( + const mx::array &Inputs, + const std::optional &InputsEmbeds = std::nullopt, + const std::optional &Mask = std::nullopt, + const std::optional>> &Cache = + std::nullopt); + std::vector> Layers; + TextConfig Config; +}; + +class LanguageModel : public vlm::LanguageModel { +public: + LanguageModel(const TextConfig &Config); + std::tuple> forward( + const mx::array &Inputs, + const std::optional>> &Cache = + std::nullopt) override { + return forward(Inputs, std::nullopt, std::nullopt, Cache); + } + std::tuple> forward( + const mx::array &Inputs, + const std::optional &InputsEmbeds = std::nullopt, + const std::optional &Mask = std::nullopt, + const std::optional>> &Cache = + std::nullopt); + std::unordered_map + sanitize(const std::unordered_map &Weights); + int headDim() const override; + int nKvHeads() const override; + int layers() const override; + std::vector> makeCache() override; + TextConfig Config; +}; + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/vision.cpp b/plugins/wasi_nn/MLX/model/gemma3/vision.cpp new file mode 100644 index 00000000..ff6e64ce --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/vision.cpp @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "model/gemma3/vision.h" +#include "mlx/activations.h" +#include "mlx/convolution.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "mlx/normalization.h" +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace gemma3 { + +VisionConfig VisionConfig::fromDict(const simdjson::dom::object &Obj) { + VisionConfig Config; + auto ModelTypeResult = Obj["model_type"].get_string(); + if (!ModelTypeResult.error()) { + Config.ModelType = std::string(ModelTypeResult.value()); + } + auto NumHiddenLayersResult = Obj["num_hidden_layers"].get_int64(); + if (!NumHiddenLayersResult.error()) { + Config.NumHiddenLayers = static_cast(NumHiddenLayersResult.value()); + } + auto HiddenSizeResult = Obj["hidden_size"].get_int64(); + if (!HiddenSizeResult.error()) { + Config.HiddenSize = static_cast(HiddenSizeResult.value()); + } + auto IntermediateSizeResult = Obj["intermediate_size"].get_int64(); + if (!IntermediateSizeResult.error()) { + Config.IntermediateSize = static_cast(IntermediateSizeResult.value()); + } + auto NumAttentionHeadsResult = Obj["num_attention_heads"].get_int64(); + if (!NumAttentionHeadsResult.error()) { + Config.NumAttentionHeads = + static_cast(NumAttentionHeadsResult.value()); + } + auto PatchSizeResult = Obj["patch_size"].get_int64(); + if (!PatchSizeResult.error()) { + Config.PatchSize = static_cast(PatchSizeResult.value()); + } + auto ImageSizeResult = Obj["image_size"].get_int64(); + if (!ImageSizeResult.error()) { + Config.ImageSize = static_cast(ImageSizeResult.value()); + } + auto NumChannelsResult = Obj["num_channels"].get_int64(); + if (!NumChannelsResult.error()) { + Config.NumChannels = static_cast(NumChannelsResult.value()); + } + auto LayerNormEpsResult = Obj["layer_norm_eps"].get_double(); + if (!LayerNormEpsResult.error()) { + Config.LayerNormEps = static_cast(LayerNormEpsResult.value()); + } + return Config; +} + +bool checkArrayShape(const mx::array &Arr) { + auto Shape = Arr.shape(); + if (Shape.size() != 4) + return false; + int OutChannels = Shape[0]; + int KH = Shape[1]; + int KW = Shape[2]; + return (OutChannels >= KH) && (OutChannels >= KW) && (KH == KW); +} + +VisionAttention::VisionAttention(int Dims, int NumHeads, + std::optional QueryInputDims, + std::optional KeyInputDims, + std::optional ValueInputDims, + std::optional ValueDims, + std::optional ValueOutputDims, bool Bias) + : NumHeads(NumHeads) { + if (Dims % NumHeads != 0) { + spdlog::error( + "[WASI-NN] MLX backend: Dims must be divisible by NumHeads"sv); + assumingUnreachable(); + } + int QInput = QueryInputDims.value_or(Dims); + int KInput = KeyInputDims.value_or(Dims); + int VInput = ValueInputDims.value_or(KInput); + int VDim = ValueDims.value_or(Dims); + int VOutDim = ValueOutputDims.value_or(Dims); + int HeadDim = Dims / NumHeads; + Scale = std::pow(HeadDim, -0.5); + registerModule("q_proj", std::make_shared(QInput, Dims, Bias)); + registerModule("k_proj", std::make_shared(KInput, Dims, Bias)); + registerModule("v_proj", std::make_shared(VInput, VDim, Bias)); + registerModule("out_proj", std::make_shared(VDim, VOutDim, Bias)); +} + +mx::array VisionAttention::forward(const mx::array &X, + const std::optional &Mask) { + mx::array Queries = + std::dynamic_pointer_cast(Submodules["q_proj"])->forward(X); + mx::array Keys = + std::dynamic_pointer_cast(Submodules["k_proj"])->forward(X); + mx::array Values = + std::dynamic_pointer_cast(Submodules["v_proj"])->forward(X); + int B = Queries.shape()[0]; + int L = Queries.shape()[1]; + Queries = transpose(reshape(Queries, {B, L, NumHeads, -1}), {0, 2, 1, 3}); + int S = Keys.shape()[1]; + Keys = transpose(reshape(Keys, {B, S, NumHeads, -1}), {0, 2, 1, 3}); + Values = transpose(reshape(Values, {B, S, NumHeads, -1}), {0, 2, 1, 3}); + mx::array Output = + Mask.has_value() ? mx::fast::scaled_dot_product_attention( + Queries, Keys, Values, Scale, Mask.value()) + : mx::fast::scaled_dot_product_attention(Queries, Keys, + Values, Scale); + Output = reshape(transpose(Output, {0, 2, 1, 3}), {B, L, -1}); + return std::dynamic_pointer_cast(Submodules["out_proj"]) + ->forward(Output); +} + +VisionMLP::VisionMLP(const VisionConfig &Config) { + registerModule("fc1", std::make_shared( + Config.HiddenSize, Config.IntermediateSize, true)); + registerModule("fc2", std::make_shared(Config.IntermediateSize, + Config.HiddenSize, true)); +} + +mx::array VisionMLP::forward(const mx::array &X) { + mx::array Out = + std::dynamic_pointer_cast(Submodules["fc1"])->forward(X); + Out = mlx::core::geluApprox(Out); + Out = std::dynamic_pointer_cast(Submodules["fc2"])->forward(Out); + return Out; +} + +EncoderLayer::EncoderLayer(const VisionConfig &Config) { + EmbedDim = Config.HiddenSize; + registerModule("self_attn", std::make_shared( + Config.HiddenSize, Config.NumAttentionHeads, + std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, true)); + registerModule("layer_norm1", std::make_shared( + EmbedDim, Config.LayerNormEps)); + registerModule("mlp", std::make_shared(Config)); + registerModule("layer_norm2", std::make_shared( + EmbedDim, Config.LayerNormEps)); +} + +mx::array EncoderLayer::forward(const mx::array &X, + const std::optional &Mask) { + mx::array R = + std::dynamic_pointer_cast(Submodules["self_attn"]) + ->forward(std::dynamic_pointer_cast( + Submodules["layer_norm1"]) + ->forward(X), + Mask); + mx::array H = X + R; + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward(std::dynamic_pointer_cast( + Submodules["layer_norm2"]) + ->forward(H)); + return H + R; +} + +Encoder::Encoder(const VisionConfig &Config) { + for (int I = 0; I < Config.NumHiddenLayers; I++) { + Layers.push_back(std::make_shared(Config)); + } + registerLayer("layers", Layers); +} + +std::pair> +Encoder::forward(const mx::array &X, + const std::optional &OutputHiddenStates, + const std::optional &Mask) { + std::vector EncoderStates; + if (OutputHiddenStates.has_value() && OutputHiddenStates.value()) + EncoderStates.push_back(X); + mx::array Out = X; + mx::array H = X; + for (auto &L : Layers) { + Out = L->forward(Out, Mask); + if (OutputHiddenStates.has_value() && OutputHiddenStates.value()) + EncoderStates.push_back(Out); + H = take(Out, 0, 0); + } + return {H, EncoderStates}; +} + +VisionEmbeddings::VisionEmbeddings(const VisionConfig &Config) + : Config(Config) { + EmbedDim = Config.HiddenSize; + ImageSize = Config.ImageSize; + PatchSize = Config.PatchSize; + registerModule("patch_embedding", std::make_shared(nn::Conv2d( + Config.NumChannels, EmbedDim, PatchSize, + {PatchSize, PatchSize}))); + NumPatches = (ImageSize / PatchSize) * (ImageSize / PatchSize); + NumPositions = NumPatches; + registerModule("position_embedding", + std::make_shared(NumPositions, EmbedDim)); +} + +mx::array VisionEmbeddings::forward(const mx::array &X) { + mx::array PatchEmbeddings = + std::dynamic_pointer_cast(Submodules["patch_embedding"]) + ->forward(X); + PatchEmbeddings = mx::flatten(PatchEmbeddings, 1, 2); + std::vector PositionIdsShapeVec; + for (int I = 0; I < NumPositions; I++) { + PositionIdsShapeVec.emplace_back(I); + } + mx::array PositionIds = + mx::array(PositionIdsShapeVec.data(), {1, NumPositions}); + mx::array Embeddings = PatchEmbeddings; + Embeddings = Embeddings + std::dynamic_pointer_cast( + Submodules["position_embedding"]) + ->forward(PositionIds); + return Embeddings; +} + +SigLipVisionModel::SigLipVisionModel(const VisionConfig &Config) { + registerModule("embeddings", std::make_shared(Config)); + registerModule("encoder", std::make_shared(Config)); + registerModule("post_layernorm", + std::make_shared(Config.HiddenSize)); +} + +std::tuple +SigLipVisionModel::forward(const mx::array &X, + const std::optional &OutputHiddenStates) { + mx::array Emb = + std::dynamic_pointer_cast(Submodules["embeddings"]) + ->forward(X); + auto EncoderOutputs = + std::dynamic_pointer_cast(Submodules["encoder"]) + ->forward(Emb, OutputHiddenStates, std::nullopt); + mx::array PoolerOutput = + std::dynamic_pointer_cast(Submodules["post_layernorm"]) + ->forward(std::get<0>(EncoderOutputs)); + return {PoolerOutput, Emb, std::get<1>(EncoderOutputs).back()}; +} + +VisionModel::VisionModel(const VisionConfig &Config) { + ModelType = Config.ModelType; + if (ModelType != "siglip_vision_model" && ModelType != "gemma3" && + ModelType != "gemma3_vision") { + spdlog::error("[WASI-NN] MLX backend: Unsupported model type: {}"sv, + ModelType); + assumingUnreachable(); + } + registerModule("vision_model", std::make_shared(Config)); +} + +std::tuple +VisionModel::forward(const mx::array &X, + const std::optional &OutputHiddenStates) { + return std::dynamic_pointer_cast( + Submodules["vision_model"]) + ->forward(X, OutputHiddenStates); +} + +std::unordered_map VisionModel::sanitize( + const std::unordered_map &Weights) { + std::unordered_map SanitizedWeights; + for (auto &Pair : Weights) { + if (Pair.first.find("patch_embedding.weight") != std::string::npos) { + if (checkArrayShape(Pair.second)) + SanitizedWeights.insert({Pair.first, Pair.second}); + else + SanitizedWeights.insert( + {Pair.first, transpose(Pair.second, {0, 2, 3, 1})}); + } else { + SanitizedWeights.insert({Pair.first, Pair.second}); + } + } + return SanitizedWeights; +} + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/vision.h b/plugins/wasi_nn/MLX/model/gemma3/vision.h new file mode 100644 index 00000000..91a421f6 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/vision.h @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" +#include "simdjson.h" +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace nn = mlx::core::nn; +namespace gemma3 { + +struct VisionConfig { + std::string ModelType = "siglip_vision_model"; + int NumHiddenLayers = 27; + int HiddenSize = 1152; + int IntermediateSize = 4304; + int NumAttentionHeads = 16; + int PatchSize = 14; + int ImageSize = 896; + int NumChannels = 3; + float LayerNormEps = 1e-6f; + static VisionConfig fromDict(const simdjson::dom::object &Obj); +}; + +bool checkArrayShape(const mx::array &Arr); + +class VisionAttention : public nn::Module { +public: + VisionAttention(int Dims, int NumHeads, + std::optional QueryInputDims = std::nullopt, + std::optional KeyInputDims = std::nullopt, + std::optional ValueInputDims = std::nullopt, + std::optional ValueDims = std::nullopt, + std::optional ValueOutputDims = std::nullopt, + bool Bias = true); + mx::array forward(const mx::array &X, + const std::optional &Mask = std::nullopt); + +private: + int NumHeads; + float Scale; +}; + +class VisionMLP : public nn::Module { +public: + VisionMLP(const VisionConfig &Config); + mx::array forward(const mx::array &X); +}; + +class EncoderLayer : public nn::Module { +public: + EncoderLayer(const VisionConfig &Config); + mx::array forward(const mx::array &X, + const std::optional &Mask = std::nullopt); + +private: + int EmbedDim; +}; + +class Encoder : public nn::Module { +public: + Encoder(const VisionConfig &Config); + std::pair> + forward(const mx::array &X, + const std::optional &OutputHiddenStates = std::nullopt, + const std::optional &Mask = std::nullopt); + +private: + std::vector> Layers; +}; + +class VisionEmbeddings : public nn::Module { +public: + VisionEmbeddings(const VisionConfig &Config); + mx::array forward(const mx::array &X); + +private: + VisionConfig Config; + int EmbedDim; + int ImageSize; + int PatchSize; + int NumPatches; + int NumPositions; +}; + +class SigLipVisionModel : public nn::Module { +public: + SigLipVisionModel(const VisionConfig &Config); + std::tuple + forward(const mx::array &X, + const std::optional &OutputHiddenStates = std::nullopt); +}; + +class VisionModel : public nn::Module { +public: + VisionModel(const VisionConfig &Config); + std::tuple + forward(const mx::array &X, + const std::optional &OutputHiddenStates = std::nullopt); + std::unordered_map + sanitize(const std::unordered_map &Weights); + +private: + std::string ModelType; +}; + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/llm/registry.cpp b/plugins/wasi_nn/MLX/model/llm/registry.cpp new file mode 100644 index 00000000..2761b5f8 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/llm/registry.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "registry.h" +#include "transformer.h" +namespace WasmEdge::Host::WASINN::MLX { +namespace llm { +std::shared_ptr llama38b(int VocabSize, float NormEps, + float RopeTheta, bool RopeTraditional) { + return std::make_shared(Transformer( + 4096, std::vector{14336}, VocabSize, 32, std::vector{32}, + std::vector{8}, NormEps, {}, RopeTraditional, RopeTheta)); +} + +std::shared_ptr llama27bChat(int VocabSize, float NormEps, + float RopeTheta, + bool RopeTraditional) { + return std::make_shared(Transformer( + 4096, std::vector{11008}, VocabSize, 32, std::vector{32}, + std::vector{32}, NormEps, {}, RopeTraditional, RopeTheta)); +} + +std::shared_ptr tinyLlama11BChatV10(int VocabSize, float NormEps, + float RopeTheta, + bool RopeTraditional) { + return std::make_shared(Transformer( + 2048, std::vector{5632}, VocabSize, 22, std::vector{32}, + std::vector{4}, NormEps, {}, RopeTraditional, RopeTheta)); +} + +} // namespace llm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/llm/registry.h b/plugins/wasi_nn/MLX/model/llm/registry.h new file mode 100644 index 00000000..359c7369 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/llm/registry.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "transformer.h" +namespace WasmEdge::Host::WASINN::MLX { +namespace llm { +std::shared_ptr llama38b(int VocabSize = 32000, + float NormEps = 1e-5, + float RopeTheta = 10000.0, + bool RopeTraditional = false); + +std::shared_ptr llama27bChat(int VocabSize = 32000, + float NormEps = 1e-5, + float RopeTheta = 10000.0, + bool RopeTraditional = false); + +std::shared_ptr tinyLlama11BChatV10(int VocabSize = 32000, + float NormEps = 1e-5, + float RopeTheta = 10000.0, + bool RopeTraditional = false); + +} // namespace llm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/llm/transformer.cpp b/plugins/wasi_nn/MLX/model/llm/transformer.cpp new file mode 100644 index 00000000..e7ac5b94 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/llm/transformer.cpp @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/transformer.h" +#include "../utils.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "transformer.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace llm { + +mx::array RMSNorm::forward(mx::array Input) { + return mx::fast::rms_norm(Input, 1.0 + Parameters.at("weight"), Eps); +} + +std::tuple> +Attention::forward(mx::array Input, std::optional Mask, + std::optional> KVCache) { + const auto &[B, L, D] = + std::tie(Input.shape()[0], Input.shape()[1], Input.shape()[2]); + mx::array Queries = + std::dynamic_pointer_cast(Submodules["q_proj"]) + ->forward(Input); + mx::array Keys = std::dynamic_pointer_cast(Submodules["k_proj"]) + ->forward(Input); + mx::array Values = std::dynamic_pointer_cast(Submodules["v_proj"]) + ->forward(Input); + Queries = transpose(reshape(Queries, {B, L, NHeads, -1}), {0, 2, 1, 3}); + Keys = transpose(reshape(Keys, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + Values = transpose(reshape(Values, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + + if (NormQKProj) { + Queries = std::dynamic_pointer_cast(Submodules["q_norm"]) + ->forward(Queries); + Keys = std::dynamic_pointer_cast(Submodules["k_norm"]) + ->forward(Keys); + } + if (KVCache) { + const auto &[KeyCache, ValueCache] = *KVCache; + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries, KeyCache.shape(2)); + Keys = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Keys, KeyCache.shape(2)); + Keys = mx::concatenate({KeyCache, Keys}, 2); + Values = mx::concatenate({ValueCache, Values}, 2); + } else { + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries); + Keys = + std::dynamic_pointer_cast(Submodules["rope"])->forward(Keys); + } + mx::array Output = + Mask.has_value() ? mx::fast::scaled_dot_product_attention( + Queries, Keys, Values, Scale, Mask.value()) + : mx::fast::scaled_dot_product_attention(Queries, Keys, + Values, Scale); + + Output = reshape(transpose(Output, {0, 2, 1, 3}), {B, L, -1}); + return {std::dynamic_pointer_cast(Submodules["o_proj"]) + ->forward(Output), + {Keys, Values}}; +} + +mx::array MLP::forward(mx::array Input) { + if (Gemma) { + return std::dynamic_pointer_cast(Submodules["down_proj"]) + ->forward(mlx::core::gelu(std::dynamic_pointer_cast( + Submodules["gate_proj"]) + ->forward(Input)) * + std::dynamic_pointer_cast(Submodules["up_proj"]) + ->forward(Input)); + } + return std::dynamic_pointer_cast(Submodules["down_proj"]) + ->forward(mlx::core::silu(std::dynamic_pointer_cast( + Submodules["gate_proj"]) + ->forward(Input)) * + std::dynamic_pointer_cast(Submodules["up_proj"]) + ->forward(Input)); +} + +std::tuple> +TransformerBlock::forward( + mx::array Input, std::optional Mask, + std::optional> KVCachePar) { + mx::array NormOutput({}); + if (!Gemma) { + NormOutput = + std::dynamic_pointer_cast(Submodules["attention_norm"]) + ->forward(Input); + } else { + NormOutput = + std::dynamic_pointer_cast(Submodules["attention_norm"]) + ->forward(Input); + } + auto [R, KVCache] = + std::dynamic_pointer_cast(Submodules["attention"]) + ->forward(NormOutput, Mask, KVCachePar); + auto H = Input + R; + if (!Gemma) { + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward( + std::dynamic_pointer_cast(Submodules["mlp_norm"]) + ->forward(H)); + } else { + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward(std::dynamic_pointer_cast(Submodules["mlp_norm"]) + ->forward(H)); + } + return {H + R, KVCache}; +} + +std::tuple>>> +Transformer::embed( + mx::array Input, + std::optional>> KVCachePar, + bool Norm) { + mx::array H = + std::dynamic_pointer_cast(Submodules["token_embed"]) + ->forward(Input); + if (Gemma) { + H = H * (pow(Dim, 0.5)); + } + std::optional Mask; + if (H.shape()[1] > 1) { + Mask = nn::MultiHeadAttention::createAdditiveCausalMask(H.shape()[1]); + Mask = astype(*Mask, H.dtype()); + } + std::vector> KVCache; + KVCache.reserve(Layers.size()); + for (size_t Idx = 0; Idx < Layers.size(); Idx++) { + std::tuple> Result = { + mx::array({}), {mx::array({}), mx::array({})}}; + if (KVCachePar) { + Result = Layers[Idx]->forward(H, Mask, (*KVCachePar)[Idx]); + } else { + Result = Layers[Idx]->forward(H, Mask, {}); + } + H = std::get<0>(Result); + KVCache.emplace_back(std::get<1>(Result)); + } + if (Norm) { + if (!Gemma) { + return {std::dynamic_pointer_cast(Submodules["norm"]) + ->forward(H), + KVCache}; + } + return {std::dynamic_pointer_cast(Submodules["norm"])->forward(H), + KVCache}; + } + return {H, KVCache}; +} + +std::tuple>>> +Transformer::forward( + mx::array Input, + std::optional>> KVCachePar) { + auto [X, KVCache] = embed(Input, KVCachePar, true); + mx::array Out({}); + if (EmbedAsHead) { + Out = std::dynamic_pointer_cast(Submodules["token_embed"]) + ->asLinear(X); + } else { + Out = std::dynamic_pointer_cast(Submodules["head"])->forward(X); + } + return {Out, KVCache}; +} + +std::tuple>>> +Transformer::stepGenerate(mx::array Input, std::optional Temp) { + // Reshape Input to input[:, None] + std::vector ReshapeDim = Input.shape(); + ReshapeDim.insert(ReshapeDim.begin(), 1); + auto [Logits, KVCache] = forward(reshape(Input, ReshapeDim)); + const int H = Logits.shape()[1] - 1; + // take logits[:, -1, :] + Logits = take(Logits, mx::array({H}), 1); + ReshapeDim = Logits.shape(); + ReshapeDim.erase(ReshapeDim.begin() + 1); + Logits = reshape(Logits, ReshapeDim); + mx::array Y({}); + if (Temp == 0) { + Y = mx::argmax(Logits, -1); + } else { + Y = mx::random::categorical(Logits * (1.0 / *Temp)); + } + return {Y, KVCache}; +} + +std::tuple>>> +Transformer::nextStepGenerate( + mx::array Y, std::optional Temp, + std::optional>> KVCachePar) { + // Reshape Y to y[:, None] + std::vector ReshapeDim = Y.shape(); + ReshapeDim.insert(ReshapeDim.begin() + 1, 1); + auto [Logits, KVCache] = forward(reshape(Y, ReshapeDim), KVCachePar); + Logits = squeeze(Logits, 1); + mx::array NextY({}); + if (Temp == 0) { + NextY = mx::argmax(Logits, -1); + } else { + NextY = mx::random::categorical(Logits * (1.0 / *Temp)); + } + return {NextY, KVCache}; +} + +enum AnserSataus { + STOP, + WAIT, + GO, +}; + +AnserSataus answerSataus(std::string Text, std::string End) { + if (endsWith(Text, End)) { + return STOP; + } + for (int Idx = 1; Idx < static_cast(End.size()); Idx++) { + if (endsWith(Text, End.substr(0, Idx))) { + return WAIT; + } + } + return GO; +} + +Transformer::LLMOutput +Transformer::generate(const std::string &Prompt, const BasePrompt &ModelPrompt, + const int MaxToken, const bool Verbose, + const std::unique_ptr &Tok) { + const std::vector Ids = Tok->Encode(Prompt); + mx::array Token = + mx::array(Ids.data(), {static_cast(Ids.size())}, mx::int32); + std::vector TokenList; + int TokenCount = 0; + int Skip = 0; + std::string Answer; + auto [Y, KVCache] = this->stepGenerate(Token, 0.1); + while (true) { + TokenCount++; + if (TokenCount > MaxToken) { + break; + } + eval(Y); + std::vector Tokens; + auto *Data = Y.data(); + for (int Idx = 0; Idx < static_cast(Y.size()); Idx++) { + Tokens.emplace_back(Data[Idx]); + } + // TODO: break when the token is the eos_token_id + TokenList.insert(TokenList.end(), Tokens.begin(), Tokens.end()); + Answer = Tok->Decode(TokenList); + const AnserSataus Status = answerSataus(Answer, ModelPrompt.TextEnd); + if (Status == STOP) { + break; + } + if (Status == GO) { + if (Verbose) { + spdlog::info("[WASI-NN] MLX backend: {}"sv, Answer.substr(Skip)); + } + Skip = Answer.size(); + } + auto [NY, NKVCache] = this->nextStepGenerate(Y, 0.1, KVCache); + Y = NY, KVCache = NKVCache; + } + return {Answer, TokenList}; +} + +} // namespace llm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/llm/transformer.h b/plugins/wasi_nn/MLX/model/llm/transformer.h new file mode 100644 index 00000000..1065e31a --- /dev/null +++ b/plugins/wasi_nn/MLX/model/llm/transformer.h @@ -0,0 +1,231 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once +#include "mlx/activations.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "mlx/normalization.h" +#include "mlx/positional_encoding.h" +#include "prompt/prompt.h" +#include +#include +#include +#include +#include +#include +#include +#include +namespace WasmEdge::Host::WASINN::MLX { +namespace nn = mlx::core::nn; +namespace llm { + +class RMSNorm : public nn::Module { + float Eps; + +public: + RMSNorm(int Dims, float Eps = 1e-5) : Eps(Eps) { + registerParameter("weight", mx::ones({Dims})); + } + mx::array forward(mx::array Input); +}; +class Attention : public nn::Module { + + int NHeads; + int NKVHeads; + bool NormQKProj; + double Scale; + +public: + Attention(int Dim, int NHeads, int NKVHeads, + std::optional HeadDimPar = 0, bool RopeTraditional = false, + float RopeTheta = 1000, + std::optional> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6) + : NHeads(NHeads), NKVHeads(NKVHeads), NormQKProj(NormQKProj) { + int HeadDim; + if (HeadDimPar) { + HeadDim = *HeadDimPar; + } else { + HeadDim = Dim / NHeads; + } + Scale = pow(HeadDim, -0.5); + registerModule("q_proj", std::make_shared( + nn::Linear(Dim, NHeads * HeadDim, false))); + registerModule("k_proj", std::make_shared( + nn::Linear(Dim, NKVHeads * HeadDim, false))); + registerModule("v_proj", std::make_shared( + nn::Linear(Dim, NKVHeads * HeadDim, false))); + registerModule("o_proj", std::make_shared( + nn::Linear(NHeads * HeadDim, Dim, false))); + + if (NormQKProj) { + registerModule("q_norm", std::make_shared( + nn::RMSNorm(HeadDim, AttentionNormEps))); + registerModule("k_norm", std::make_shared( + nn::RMSNorm(HeadDim, AttentionNormEps))); + } + float RopeScale; + if (RopeScaling && (*RopeScaling)["type"] == "linear") { + RopeScale = 1 / stof((*RopeScaling)["factor"]); + } else { + RopeScale = 1; + } + + registerModule("rope", + std::make_shared(nn::RoPE(HeadDim, RopeTraditional, + RopeTheta, RopeScale))); + } + std::tuple> + forward(mx::array Input, std::optional Mask = {}, + std::optional> KVCache = {}); +}; +class MLP : public nn::Module { + bool Gemma; + +public: + MLP(int Dim, int HiddenDim, bool Gemma = false) : Gemma(Gemma) { + registerModule("gate_proj", std::make_shared( + nn::Linear(Dim, HiddenDim, false))); + registerModule("down_proj", std::make_shared( + nn::Linear(HiddenDim, Dim, false))); + registerModule("up_proj", std::make_shared( + nn::Linear(Dim, HiddenDim, false))); + } + mx::array forward(mx::array Input); +}; +class TransformerBlock : public nn::Module { + bool Gemma; + +public: + TransformerBlock(int Dim, int NHeads, int NKVHeads, int HiddenDim, + float NormEps, std::optional HeadDim = {}, + bool RopeTraditional = false, float RopeTheta = 1000, + std::optional> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6, + bool Gemma = false) + : Gemma(Gemma) { + registerModule("attention", + std::make_shared(Attention( + Dim, NHeads, NKVHeads, HeadDim, RopeTraditional, + RopeTheta, RopeScaling, NormQKProj, AttentionNormEps))); + registerModule("mlp", std::make_shared(MLP(Dim, HiddenDim, Gemma))); + if (!Gemma) { + registerModule("attention_norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); + registerModule("mlp_norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); + } else { + registerModule("attention_norm", + std::make_shared(RMSNorm(Dim, NormEps))); + registerModule("mlp_norm", + std::make_shared(RMSNorm(Dim, NormEps))); + } + } + std::tuple> + forward(mx::array Input, std::optional Mask = {}, + std::optional> KVCachePar = {}); +}; +class Transformer : public nn::Module { + int Dim; + std::optional> HiddenDim; + bool Gemma; + bool EmbedAsHead; + std::vector> Layers{}; + +public: + Transformer( + int Dim, std::optional> HiddenDim, int VocabSize, + int NLayers, std::optional> NHeads, + std::optional> NKVHeads = {}, float NormEps = 1e-5, + std::optional HeadDim = {}, bool RopeTraditional = false, + float RopeTheta = 1000, + std::optional>> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6, + bool Gemma = false, bool EmbedAsHeadPar = false) + : Dim(Dim), HiddenDim(HiddenDim), Gemma(Gemma), + EmbedAsHead(EmbedAsHeadPar) { + if (VocabSize <= 0) { + spdlog::error("VocabSize must be greater than 0."); + assumingUnreachable(); + } + EmbedAsHead = Gemma ? true : EmbedAsHead; + if (!NKVHeads) { + NKVHeads = NHeads; + } + registerModule("token_embed", std::make_shared( + nn::Embedding(VocabSize, Dim))); + if (HiddenDim->size() == 1) { + while (static_cast(HiddenDim->size()) < NLayers) { + HiddenDim->emplace_back((*HiddenDim)[0]); + } + } + if (NHeads->size() == 1) { + while (static_cast(NHeads->size()) < NLayers) { + NHeads->emplace_back((*NHeads)[0]); + } + } + if (NKVHeads->size() == 1) { + while (static_cast(NKVHeads->size()) < NLayers) { + NKVHeads->emplace_back((*NKVHeads)[0]); + } + } + Layers.reserve(NLayers); + for (int Idx = 0; Idx < NLayers; Idx++) { + if (RopeScaling) { + Layers.push_back(std::make_shared(TransformerBlock( + Dim, (*NHeads)[Idx], (*NKVHeads)[Idx], (*HiddenDim)[Idx], NormEps, + HeadDim, RopeTraditional, RopeTheta, (*RopeScaling)[Idx], + NormQKProj, AttentionNormEps, Gemma))); + } else { + Layers.push_back(std::make_shared(TransformerBlock( + Dim, (*NHeads)[Idx], (*NKVHeads)[Idx], (*HiddenDim)[Idx], NormEps, + HeadDim, RopeTraditional, RopeTheta, {}, NormQKProj, + AttentionNormEps, Gemma))); + } + } + registerLayer("layers", Layers); + if (!Gemma) { + registerModule("norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); + } else { + registerModule("norm", std::make_shared(RMSNorm(Dim, NormEps))); + } + if (!EmbedAsHead) { + registerModule("head", std::make_shared( + nn::Linear(Dim, VocabSize, false))); + } + } + struct LLMOutput { + std::string Answer; + std::vector TokenList; + }; + std::tuple>>> + embed(mx::array Input, + std::optional>> + KVCachePar = {}, + bool Norm = false); + std::tuple>>> + forward(mx::array Input, + std::optional>> + KVCachePar = {}); + std::tuple>>> + stepGenerate(mx::array Input, std::optional Temp = 0.0); + std::tuple>>> + nextStepGenerate(mx::array Y, std::optional Temp = 0.0, + std::optional>> + KVCachePar = {}); + LLMOutput generate(const std::string &Prompt, const BasePrompt &ModelPrompt, + const int MaxToken, const bool Verbose, + const std::unique_ptr &Tok); +}; +} // namespace llm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/transformer.cpp b/plugins/wasi_nn/MLX/model/transformer.cpp new file mode 100644 index 00000000..79ca927c --- /dev/null +++ b/plugins/wasi_nn/MLX/model/transformer.cpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/transformer.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "model/transformer.h" +namespace WasmEdge::Host::WASINN::MLX { + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +mx::array RMSNorm::forward(mx::array Input) { + return mx::fast::rms_norm(Input, 1.0 + Parameters.at("weight"), Eps); +} + +std::tuple> +Attention::forward(mx::array Input, std::optional Mask, + std::optional> KVCache) { + const auto &[B, L, D] = + std::tie(Input.shape()[0], Input.shape()[1], Input.shape()[2]); + mx::array Queries = + std::dynamic_pointer_cast(Submodules["q_proj"]) + ->forward(Input); + mx::array Keys = std::dynamic_pointer_cast(Submodules["k_proj"]) + ->forward(Input); + mx::array Values = std::dynamic_pointer_cast(Submodules["v_proj"]) + ->forward(Input); + Queries = transpose(reshape(Queries, {B, L, NHeads, -1}), {0, 2, 1, 3}); + Keys = transpose(reshape(Keys, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + Values = transpose(reshape(Values, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + + if (NormQKProj) { + Queries = std::dynamic_pointer_cast(Submodules["q_norm"]) + ->forward(Queries); + Keys = std::dynamic_pointer_cast(Submodules["k_norm"]) + ->forward(Keys); + } + if (KVCache) { + const auto &[KeyCache, ValueCache] = *KVCache; + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries, KeyCache.shape(2)); + Keys = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Keys, KeyCache.shape(2)); + Keys = mx::concatenate({KeyCache, Keys}, 2); + Values = mx::concatenate({ValueCache, Values}, 2); + } else { + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries); + Keys = + std::dynamic_pointer_cast(Submodules["rope"])->forward(Keys); + } + mx::array Output = mx::fast::scaled_dot_product_attention( + Queries, Keys, Values, Scale, Mask); + Output = reshape(transpose(Output, {0, 2, 1, 3}), {B, L, -1}); + return {std::dynamic_pointer_cast(Submodules["o_proj"]) + ->forward(Output), + {Keys, Values}}; +} + +mx::array MLP::forward(mx::array Input) { + if (Gemma) { + return std::dynamic_pointer_cast(Submodules["down_proj"]) + ->forward(mlx::core::gelu(std::dynamic_pointer_cast( + Submodules["gate_proj"]) + ->forward(Input)) * + std::dynamic_pointer_cast(Submodules["up_proj"]) + ->forward(Input)); + } + return std::dynamic_pointer_cast(Submodules["down_proj"]) + ->forward(mlx::core::silu(std::dynamic_pointer_cast( + Submodules["gate_proj"]) + ->forward(Input)) * + std::dynamic_pointer_cast(Submodules["up_proj"]) + ->forward(Input)); +} + +std::tuple> +TransformerBlock::forward( + mx::array Input, std::optional Mask, + std::optional> KVCachePar) { + mx::array NormOutput = {}; + if (!Gemma) { + NormOutput = + std::dynamic_pointer_cast(Submodules["attention_norm"]) + ->forward(Input); + } else { + NormOutput = + std::dynamic_pointer_cast(Submodules["attention_norm"]) + ->forward(Input); + } + auto [R, KVCache] = + std::dynamic_pointer_cast(Submodules["attention"]) + ->forward(NormOutput, Mask, KVCachePar); + auto H = Input + R; + if (!Gemma) { + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward( + std::dynamic_pointer_cast(Submodules["mlp_norm"]) + ->forward(H)); + } else { + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward(std::dynamic_pointer_cast(Submodules["mlp_norm"]) + ->forward(H)); + } + return {H + R, KVCache}; +} + +std::tuple>>> +Transformer::embed( + mx::array Input, + std::optional>> KVCachePar, + bool Norm) { + mx::array H = + std::dynamic_pointer_cast(Submodules["token_embed"]) + ->forward(Input); + if (Gemma) { + H = H * (pow(Dim, 0.5)); + } + std::optional Mask; + if (H.shape()[1] > 1) { + Mask = nn::MultiHeadAttention::createAdditiveCausalMask(H.shape()[1]); + Mask = astype(*Mask, H.dtype()); + } + std::vector> KVCache; + KVCache.reserve(Layers.size()); + for (size_t Idx = 0; Idx < Layers.size(); Idx++) { + std::tuple> Result = {{}, + {{}, {}}}; + if (KVCachePar) { + Result = Layers[Idx]->forward(H, Mask, (*KVCachePar)[Idx]); + } else { + Result = Layers[Idx]->forward(H, Mask, {}); + } + H = std::get<0>(Result); + KVCache.emplace_back(std::get<1>(Result)); + } + if (Norm) { + if (!Gemma) { + return {std::dynamic_pointer_cast(Submodules["norm"]) + ->forward(H), + KVCache}; + } + return {std::dynamic_pointer_cast(Submodules["norm"])->forward(H), + KVCache}; + } + return {H, KVCache}; +} + +std::tuple>>> +Transformer::forward( + mx::array Input, + std::optional>> KVCachePar) { + auto [X, KVCache] = embed(Input, KVCachePar, true); + mx::array Out = {}; + if (EmbedAsHead) { + Out = std::dynamic_pointer_cast(Submodules["token_embed"]) + ->asLinear(X); + } else { + Out = std::dynamic_pointer_cast(Submodules["head"])->forward(X); + } + return {Out, KVCache}; +} + +std::tuple>>> +Transformer::generate(mx::array Input, std::optional Temp) { + // Reshape Input to input[:, None] + std::vector ReshapeDim = Input.shape(); + ReshapeDim.insert(ReshapeDim.begin(), 1); + auto [Logits, KVCache] = forward(reshape(Input, ReshapeDim)); + const int H = Logits.shape()[1] - 1; + // take logits[:, -1, :] + Logits = take(Logits, mx::array({H}), 1); + ReshapeDim = Logits.shape(); + ReshapeDim.erase(ReshapeDim.begin() + 1); + Logits = reshape(Logits, ReshapeDim); + mx::array Y = {}; + if (Temp == 0) { + Y = mx::argmax(Logits, -1); + } else { + Y = mx::random::categorical(Logits * (1.0 / *Temp)); + } + return {Y, KVCache}; +} + +std::tuple>>> +Transformer::nextGenerate( + mx::array Y, std::optional Temp, + std::optional>> KVCachePar) { + // Reshape Y to y[:, None] + std::vector ReshapeDim = Y.shape(); + ReshapeDim.insert(ReshapeDim.begin() + 1, 1); + auto [Logits, KVCache] = forward(reshape(Y, ReshapeDim), KVCachePar); + Logits = squeeze(Logits, 1); + mx::array NextY = {}; + if (Temp == 0) { + NextY = mx::argmax(Logits, -1); + } else { + NextY = mx::random::categorical(Logits * (1.0 / *Temp)); + } + return {NextY, KVCache}; +} + +} // namespace WasmEdge::Host::WASINN::MLX + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/transformer.h b/plugins/wasi_nn/MLX/model/transformer.h new file mode 100644 index 00000000..e555a8ea --- /dev/null +++ b/plugins/wasi_nn/MLX/model/transformer.h @@ -0,0 +1,241 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/activations.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "mlx/normalization.h" +#include "mlx/positional_encoding.h" +namespace WasmEdge::Host::WASINN::MLX { + +#include +#include +#include + +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +namespace nn = mlx::core::nn; + +class RMSNorm : public nn::Module { + float Eps; + +public: + RMSNorm(int Dims, float Eps = 1e-5) : Eps(Eps) { + registerParameter("weight", mx::ones({Dims})); + } + + mx::array forward(mx::array Input); +}; + +class Attention : public nn::Module { + int NHeads; + int NKVHeads; + bool NormQKProj; + double Scale; + +public: + Attention(int Dim, int NHeads, int NKVHeads, + std::optional HeadDimPar = 0, bool RopeTraditional = false, + float RopeTheta = 1000, + std::optional> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6) + : NHeads(NHeads), NKVHeads(NKVHeads), NormQKProj(NormQKProj) { + int HeadDim; + if (HeadDimPar) { + HeadDim = *HeadDimPar; + } else { + HeadDim = Dim / NHeads; + } + Scale = pow(HeadDim, -0.5); + registerModule("q_proj", std::make_shared( + nn::Linear(Dim, NHeads * HeadDim, false))); + registerModule("k_proj", std::make_shared( + nn::Linear(Dim, NKVHeads * HeadDim, false))); + registerModule("v_proj", std::make_shared( + nn::Linear(Dim, NKVHeads * HeadDim, false))); + registerModule("o_proj", std::make_shared( + nn::Linear(NHeads * HeadDim, Dim, false))); + + if (NormQKProj) { + registerModule("q_norm", std::make_shared( + nn::RMSNorm(HeadDim, AttentionNormEps))); + registerModule("k_norm", std::make_shared( + nn::RMSNorm(HeadDim, AttentionNormEps))); + } + float RopeScale; + if (RopeScaling && (*RopeScaling)["type"] == "linear") { + RopeScale = 1 / stof((*RopeScaling)["factor"]); + } else { + RopeScale = 1; + } + + registerModule("rope", + std::make_shared(nn::RoPE(HeadDim, RopeTraditional, + RopeTheta, RopeScale))); + } + + std::tuple> + forward(mx::array Input, std::optional Mask = {}, + std::optional> KVCache = {}); +}; + +class MLP : public nn::Module { + bool Gemma; + +public: + MLP(int Dim, int HiddenDim, bool Gemma = false) : Gemma(Gemma) { + registerModule("gate_proj", std::make_shared( + nn::Linear(Dim, HiddenDim, false))); + registerModule("down_proj", std::make_shared( + nn::Linear(HiddenDim, Dim, false))); + registerModule("up_proj", std::make_shared( + nn::Linear(Dim, HiddenDim, false))); + } + + mx::array forward(mx::array Input); +}; + +class TransformerBlock : public nn::Module { + bool Gemma; + +public: + TransformerBlock(int Dim, int NHeads, int NKVHeads, int HiddenDim, + float NormEps, std::optional HeadDim = {}, + bool RopeTraditional = false, float RopeTheta = 1000, + std::optional> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6, + bool Gemma = false) + : Gemma(Gemma) { + registerModule("attention", + std::make_shared(Attention( + Dim, NHeads, NKVHeads, HeadDim, RopeTraditional, + RopeTheta, RopeScaling, NormQKProj, AttentionNormEps))); + registerModule("mlp", std::make_shared(MLP(Dim, HiddenDim, Gemma))); + if (!Gemma) { + registerModule("attention_norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); + registerModule("mlp_norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); + } else { + registerModule("attention_norm", + std::make_shared(RMSNorm(Dim, NormEps))); + registerModule("mlp_norm", + std::make_shared(RMSNorm(Dim, NormEps))); + } + } + + std::tuple> + forward(mx::array Input, std::optional Mask = {}, + std::optional> KVCachePar = {}); +}; + +class Transformer : public nn::Module { + int Dim; + std::optional> HiddenDim; + bool Gemma; + bool EmbedAsHead; + std::vector> Layers; + +public: + Transformer( + int Dim, std::optional> HiddenDim, int VocabSize, + int NLayers, std::optional> NHeads, + std::optional> NKVHeads = {}, float NormEps = 1e-5, + std::optional HeadDim = {}, bool RopeTraditional = false, + float RopeTheta = 1000, + std::optional>> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6, + bool Gemma = false, bool EmbedAsHeadPar = false) + : Dim(Dim), HiddenDim(HiddenDim), Gemma(Gemma), + EmbedAsHead(EmbedAsHeadPar) { + if (VocabSize <= 0) { + spdlog::error( + "[WASI-NN] MLX backend: VocabSize must be greater than 0."sv); + assumingUnreachable(); + } + EmbedAsHead = Gemma ? true : EmbedAsHead; + if (!NKVHeads) { + NKVHeads = NHeads; + } + registerModule("token_embed", std::make_shared( + nn::Embedding(VocabSize, Dim))); + if (HiddenDim->size() == 1) { + while (static_cast(HiddenDim->size()) < NLayers) { + HiddenDim->emplace_back((*HiddenDim)[0]); + } + } + if (NHeads->size() == 1) { + while (static_cast(NHeads->size()) < NLayers) { + NHeads->emplace_back((*NHeads)[0]); + } + } + if (NKVHeads->size() == 1) { + while (static_cast(NKVHeads->size()) < NLayers) { + NKVHeads->emplace_back((*NKVHeads)[0]); + } + } + Layers.reserve(NLayers); + for (int Idx = 0; Idx < NLayers; Idx++) { + if (RopeScaling) { + Layers.push_back(std::make_shared(TransformerBlock( + Dim, (*NHeads)[Idx], (*NKVHeads)[Idx], (*HiddenDim)[Idx], NormEps, + HeadDim, RopeTraditional, RopeTheta, (*RopeScaling)[Idx], + NormQKProj, AttentionNormEps, Gemma))); + } else { + Layers.push_back(std::make_shared(TransformerBlock( + Dim, (*NHeads)[Idx], (*NKVHeads)[Idx], (*HiddenDim)[Idx], NormEps, + HeadDim, RopeTraditional, RopeTheta, {}, NormQKProj, + AttentionNormEps, Gemma))); + } + } + registerLayer("layers", Layers); + if (!Gemma) { + registerModule("norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); + } else { + registerModule("norm", std::make_shared(RMSNorm(Dim, NormEps))); + } + if (!EmbedAsHead) { + registerModule("head", std::make_shared( + nn::Linear(Dim, VocabSize, false))); + } + } + + std::tuple>>> + embed(mx::array Input, + std::optional>> + KVCachePar = {}, + bool Norm = false); + + std::tuple>>> + forward(mx::array Input, + std::optional>> + KVCachePar = {}); + + std::tuple>>> + generate(mx::array Input, std::optional Temp = 0.0); + + std::tuple>>> + nextGenerate(mx::array Y, std::optional Temp = 0.0, + std::optional>> + KVCachePar = {}); +}; + +} // namespace WasmEdge::Host::WASINN::MLX + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/utils.cpp b/plugins/wasi_nn/MLX/model/utils.cpp new file mode 100644 index 00000000..4681b429 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/utils.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "model/utils.h" +#include "host/wasi/vfs_io.h" +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +std::vector splitString(const std::string &S, char Delim) { + std::vector Result; + std::stringstream SS(S); + std::string Item; + while (std::getline(SS, Item, Delim)) { + Result.emplace_back(Item); + } + return Result; +} + +std::string joinString(std::vector &S, char Delim) { + std::string Result; + for (size_t Idx = 0; Idx < S.size(); Idx++) { + Result += S[Idx]; + if (Idx != S.size() - 1) { + Result += Delim; + } + } + return Result; +} + +bool endsWith(std::string const &Value, std::string const &Ending) { + if (Ending.size() > Value.size()) + return false; + return std::equal(Ending.rbegin(), Ending.rend(), Value.rbegin()); +} + +bool startsWith(std::string const &Value, std::string const &Starting) { + if (Starting.size() > Value.size()) + return false; + return std::equal(Starting.begin(), Starting.end(), Value.begin()); +} + +void saveWeights(const std::unordered_map &Weights, + const std::string Path) { + if (endsWith(Path, ".safetensors")) { + mx::save_safetensors(Path, Weights, {{"format", "mlx"}}); + } else { + spdlog::error("[WASI-NN] MLX backend: Unsupported file format"sv); + assumingUnreachable(); + } +} + +void saveWeights(const mx::array &Weights, const std::string &Path) { + if (endsWith(Path, ".npz")) { + mx::save(Path, Weights); + } else { + spdlog::error("[WASI-NN] MLX backend: Unsupported file format"sv); + assumingUnreachable(); + } +} + +std::string loadBytesFromFile(const std::string &Path, + const Host::WASI::Environ *Env) { + WasmEdge::FStream::IFStream Fs(Path, Env); + if (Fs.fail()) { + spdlog::error("[WASI-NN] MLX backend: Cannot open {}."sv, Path); + return ""; + } + std::string Data; + Fs.seekg(0, std::ios::end); + const size_t Size = static_cast(Fs.tellg()); + Fs.seekg(0, std::ios::beg); + Data.resize(Size); + Fs.read(Data.data(), Size); + return Data; +} + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/utils.h b/plugins/wasi_nn/MLX/model/utils.h new file mode 100644 index 00000000..390413fd --- /dev/null +++ b/plugins/wasi_nn/MLX/model/utils.h @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" + +#include "host/wasi/vfs_io.h" +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +std::vector splitString(const std::string &S, char Delim); + +std::string joinString(std::vector &S, char Delim); + +bool endsWith(std::string const &Value, std::string const &Ending); + +bool startsWith(std::string const &Value, std::string const &Starting); + +void saveWeights(const std::unordered_map &Weights, + const std::string Path); + +void saveWeights(const mx::array &Weights, const std::string &Path); + +std::string loadBytesFromFile(const std::string &Path, + const Host::WASI::Environ *Env); + +void fillPlaceholders(std::ostringstream &Oss, const std::string &Fmt, + size_t &Pos); + +template std::string toString(const T &Value) { + std::ostringstream Oss; + Oss << Value; + return Oss.str(); +} + +template std::string toString(const std::vector &Vec) { + std::ostringstream Oss; + Oss << "["; + for (size_t I = 0; I < Vec.size(); I++) { + Oss << toString(Vec[I]); + if (I + 1 < Vec.size()) { + Oss << ", "; + } + } + Oss << "]"; + return Oss.str(); +} + +template +void fillPlaceholders(std::ostringstream &Oss, const std::string &Fmt, + size_t &Pos, T &&Value, Args &&...args) { + auto PlaceholderPos = Fmt.find("{}", Pos); + if (PlaceholderPos == std::string::npos) { + Oss << Fmt.substr(Pos); + return; + } + Oss << Fmt.substr(Pos, PlaceholderPos - Pos); + Oss << toString(Value); + Pos = PlaceholderPos + 2; + fillPlaceholders(Oss, Fmt, Pos, std::forward(args)...); +} + +template +std::string formatStr(const std::string &Fmt, Args &&...args) { + std::ostringstream Oss; + size_t Pos = 0; + fillPlaceholders(Oss, Fmt, Pos, std::forward(args)...); + return Oss.str(); +} + +template void debug(const std::string &fmt, Args &&...args) { + std::cout << "[DEBUG] " << formatStr(fmt, std::forward(args)...) + << std::endl; +} + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/vlm_base.cpp b/plugins/wasi_nn/MLX/model/vlm_base.cpp new file mode 100644 index 00000000..7e3ec447 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/vlm_base.cpp @@ -0,0 +1,538 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "vlm_base.h" +#include "common/errcode.h" +#include "model/vlm_sampling.h" +#include "spdlog/spdlog.h" +#include +#include +#include +#include +#include +#include +#include +#include +namespace WasmEdge::Host::WASINN::MLX { +namespace vlm { + +// BaseCache implementation +std::vector BaseCache::getState() const { return {}; } + +void BaseCache::setState(const std::vector &State) { + if (!State.empty()) { + spdlog::error( + "[WASI-NN] MLX backend: This cache has no state but a state was set."sv); + assumingUnreachable(); + } +} + +std::string BaseCache::getMetaState() const { return ""; } + +void BaseCache::setMetaState(const std::string &Value) { + if (!Value.empty()) { + spdlog::error( + "[WASI-NN] MLX backend: This cache has no meta_state but a meta_state was set."sv); + assumingUnreachable(); + } +} + +bool BaseCache::isTrimmable() const { return false; } + +int BaseCache::trim(int) { return 0; } + +// KVCache implementation +KVCache::KVCache(int HeadDim, int NKVHeads, int Step) + : NKVHeads(NKVHeads), KHeadDim(HeadDim), VHeadDim(HeadDim), Step(Step) {} + +KVCache::KVCache(std::pair HeadDims, int NKVHeads, int Step) + : NKVHeads(NKVHeads), KHeadDim(HeadDims.first), VHeadDim(HeadDims.second), + Step(Step) {} + +std::tuple +KVCache::updateAndFetch(const mx::array &NewKeys, const mx::array &NewValues) { + update(NewKeys, NewValues); + return fetch(); +} + +std::tuple KVCache::fetch() const { + // self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + mx::array Indices = mx::arange(Offset); + mx::array KeysSlice = take(Keys, Indices, -2); + mx::array ValuesSlice = take(Values, Indices, -2); + return {KeysSlice, ValuesSlice}; +} + +void KVCache::update(const mx::array &NewKeys, const mx::array &NewValues) { + int Prev = Offset; + std::vector NewShape = NewKeys.shape(); + int NewLen = NewShape[2]; + + if (Keys.size() == 0 || (Prev + NewLen) > Keys.shape()[2]) { + int NSteps = (Step + NewLen - 1) / Step; + int NewCapacity = NSteps * Step; + std::vector KShape = {1, NKVHeads, NewCapacity, KHeadDim}; + std::vector VShape = {1, NKVHeads, NewCapacity, VHeadDim}; + mx::array NewK = mx::zeros(KShape, NewKeys.dtype()); + mx::array NewV = mx::zeros(VShape, NewValues.dtype()); + if (Keys.size() != 0) { + if (Prev % Step != 0) { + mx::array Indices = mx::arange(Prev); + Keys = take(Keys, Indices, -2); + Values = take(Values, Indices, -2); + } + Keys = mx::concatenate({Keys, NewK}, 2); + Values = mx::concatenate({Values, NewV}, 2); + } else { + Keys = NewK; + Values = NewV; + } + } + + Offset += NewLen; + // self.keys[..., prev : self.offset, :] = keys + // self.values[..., prev : self.offset, :] = values + auto End = NewKeys.shape(); + std::vector Start(End.size(), 0); + std::vector Stride(End.size(), 1); + Start[End.size() - 2] = Prev; + End[End.size() - 2] = Offset; + Keys = mx::slice_update(Keys, NewKeys, Start, End, Stride); + Values = mx::slice_update(Values, NewValues, Start, End, Stride); +} + +std::vector KVCache::getState() const { + if (Offset == Keys.shape()[2]) { + return {Keys, Values}; + } + mx::array Indices = mx::arange(Offset); + mx::array KeysSlice = take(Keys, Indices, -2); + mx::array ValuesSlice = take(Values, Indices, -2); + return {KeysSlice, ValuesSlice}; +} + +void KVCache::setState(const std::vector &State) { + if (State.size() != 2) { + spdlog::error( + "[WASI-NN] MLX backend: KVCache state must contain exactly two arrays"sv); + assumingUnreachable(); + } + Keys = State[0]; + Values = State[1]; + Offset = Keys.shape()[2]; +} + +bool KVCache::isTrimmable() const { return true; } + +int KVCache::trim(int N) { + N = std::min(Offset, N); + Offset -= N; + return N; +} + +// RotatingKVCache implementation +RotatingKVCache::RotatingKVCache(int MaxSize, int Keep, int StepSize) + : KVCache(0, 0, StepSize), Keep(Keep), MaxSize(MaxSize), Idx(0) {} + +mx::array RotatingKVCache::trim(int TrimSize, const mx::array &V, + std::optional Append) { + std::vector ToCat; + + if (TrimSize > 0) { + mx::array KeepIndices = mx::arange(Keep); + mx::array TrimIndices = mx::arange(TrimSize + Keep, V.shape()[2]); + mx::array KeepPart = take(V, KeepIndices, -2); + mx::array TrimPart = take(V, TrimIndices, -2); + ToCat = {KeepPart, TrimPart}; + } else { + ToCat = {V}; + } + + if (Append.has_value()) { + ToCat.push_back(Append.value()); + } + + return mx::concatenate(ToCat, 2); +} + +mx::array RotatingKVCache::temporalOrder(const mx::array &V) { + if (Idx == V.shape()[2]) { + return V; + } + if (Idx < Offset) { + mx::array KeepIndices = mx::arange(Keep); + mx::array IdxToEndIndices = mx::arange(Idx, V.shape()[2]); + mx::array KeepToIdxIndices = mx::arange(Keep, Idx); + + mx::array KeepPart = take(V, KeepIndices, -2); + mx::array IdxToEndPart = take(V, IdxToEndIndices, -2); + mx::array KeepToIdxPart = take(V, KeepToIdxIndices, -2); + + return mx::concatenate({KeepPart, IdxToEndPart, KeepToIdxPart}, 2); + } + mx::array IdxIndices = mx::arange(Idx); + return take(V, IdxIndices, -2); +} + +std::tuple +RotatingKVCache::updateConcat(const mx::array &NewKeys, + const mx::array &NewValues) { + if (Keys.size() == 0) { + Keys = NewKeys; + Values = NewValues; + } else { + // Put the keys/values in temporal order to preserve context + Keys = temporalOrder(Keys); + Values = temporalOrder(Values); + + // The largest size is MaxSize + S to ensure every token gets at least + // MaxSize context + int TrimSize = Idx - MaxSize; + Keys = trim(TrimSize, Keys, NewKeys); + Values = trim(TrimSize, Values, NewValues); + } + + Offset += NewKeys.shape()[2]; + Idx = Keys.shape()[2]; + return {Keys, Values}; +} + +std::tuple +RotatingKVCache::updateInPlace(const mx::array &NewKeys, + const mx::array &NewValues) { + // May not have hit the max size yet, so potentially keep growing the cache + std::vector KeysShape = NewKeys.shape(); + int B = KeysShape[0]; + int NKVHeads = KeysShape[1]; + int S = KeysShape[2]; + int KHeadDim = KeysShape[3]; + int VHeadDim = NewValues.shape()[3]; + + int Prev = Offset; + if (Keys.size() == 0 || + (Prev >= Keys.shape()[2] && Keys.shape()[2] < MaxSize)) { + int NewSize = std::min(Step, MaxSize - Prev); + std::vector KShape = {B, NKVHeads, NewSize, KHeadDim}; + std::vector VShape = {B, NKVHeads, NewSize, VHeadDim}; + + mx::array NewK = mx::zeros(KShape, NewKeys.dtype()); + mx::array NewV = mx::zeros(VShape, NewValues.dtype()); + + if (Keys.size() != 0) { + Keys = mx::concatenate({Keys, NewK}, 2); + Values = mx::concatenate({Values, NewV}, 2); + } else { + Keys = NewK; + Values = NewV; + } + Idx = Prev; + } + + // Trim if needed + int TrimSize = Keys.shape()[2] - MaxSize; + if (TrimSize > 0) { + Keys = trim(TrimSize, Keys); + Values = trim(TrimSize, Values); + Idx = MaxSize; + } + + // Rotate + if (Idx == MaxSize) { + Idx = Keep; + } + + // Assign + std::vector KeysEnd = Keys.shape(); + std::vector Start(KeysEnd.size(), 0); + std::vector ValuesEnd = Values.shape(); + std::vector Stride(KeysEnd.size(), 1); + Start[Start.size() - 2] = Idx; + KeysEnd[KeysEnd.size() - 2] = Idx + S; + ValuesEnd[KeysEnd.size() - 2] = Idx + S; + Keys = mx::slice_update(Keys, NewKeys, Start, KeysEnd, Stride); + Values = mx::slice_update(Values, NewValues, Start, ValuesEnd, Stride); + + Offset += S; + Idx += S; + + // If the buffer is not full, slice off the end + if (Offset < MaxSize) { + mx::array OffsetIndices = mx::arange(Offset); + mx::array KeysSlice = take(Keys, OffsetIndices, -2); + mx::array ValuesSlice = take(Values, OffsetIndices, -2); + return {KeysSlice, ValuesSlice}; + } + + return {Keys, Values}; +} + +std::tuple +RotatingKVCache::updateAndFetch(const mx::array &NewKeys, + const mx::array &NewValues) { + if (NewKeys.shape()[2] == 1) { + return updateInPlace(NewKeys, NewValues); + } + return updateConcat(NewKeys, NewValues); +} + +std::string RotatingKVCache::getMetaState() const { + return std::to_string(Keep) + "," + std::to_string(MaxSize) + "," + + std::to_string(Step) + "," + std::to_string(Offset) + "," + + std::to_string(Idx); +} + +void RotatingKVCache::setMetaState(const std::string &Value) { + std::stringstream SS(Value); + std::string Item; + std::vector Values; + + while (std::getline(SS, Item, ',')) { + Values.push_back(std::stoi(Item)); + } + + if (Values.size() == 5) { + Keep = Values[0]; + MaxSize = Values[1]; + Step = Values[2]; + Offset = Values[3]; + Idx = Values[4]; + } else { + spdlog::error("[WASI-NN] MLX backend: Invalid meta state format."sv); + assumingUnreachable(); + } +} + +bool RotatingKVCache::isTrimmable() const { return Offset < MaxSize; } + +int RotatingKVCache::trim(int N) { + N = std::min(Offset, N); + Offset -= N; + Idx -= N; + return N; +} + +mx::array createAdditiveCausalMask(int N, int Offset) { + auto Rinds = mx::arange(Offset + N); + mx::array Linds = mx::array({}); + if (Offset) { + Linds = mx::arange(Offset, Offset + N); + } else { + Linds = Rinds; + } + // mask = linds[:, None] < rinds[None] + return mx::less(mx::expand_dims(Linds, 1), mx::expand_dims(Rinds, 0)) * -1e9; +} + +std::optional createAttentionMask( + mx::array H, + std::optional>> Cache) { + int T = H.shape()[1]; + std::optional Mask = std::nullopt; + if (T > 1) { + int Offset = 0; + if (Cache.has_value() && Cache.value().size() > 0 && + Cache.value()[0] != nullptr) { + auto C = Cache.value()[0]; + auto RotCache = std::dynamic_pointer_cast(C); + if (RotCache) { + Offset = std::min(RotCache->MaxSize - 1, RotCache->Offset); + } else { + Offset = C->Offset; + } + } + Mask = createAdditiveCausalMask(T, Offset); + Mask = mx::astype(Mask.value(), H.dtype()); + } + return Mask; +} + +std::vector Module::generate( + const std::string &Prompt, std::optional Image, bool Verbose, + std::map> + Kwargs) { + + if (Verbose) { + spdlog::info("=========="sv); + if (Image.has_value()) { + spdlog::info("Files: {}\n"sv, Image.value()); + } else if (Kwargs.count("Video") > 0) { + /* Print video path */ + } + spdlog::info("Prompt: {}"sv, Prompt); + } + + std::string Text = ""; + // stream generate + mx::array InputIds = mx::array({}); + mx::array PixelValues = mx::array({}); + mx::array Mask = mx::array({}); + + // For pixel_values + // int ImageTokenIndex; + // auto ImageTokenIndexIt = Kwargs.find("image_token_index"); + // if (ImageTokenIndexIt != Kwargs.end()) { + // if (auto *ImageTokenIndexPtr = + // std::get_if(&ImageTokenIndexIt->second)) { + // ImageTokenIndex = *ImageTokenIndexPtr; + // } else { + // assumingUnreachable(); + // } + // } else { + // assumingUnreachable(); + // } + if (Kwargs.count("pixel_values") == 0) { + spdlog::error("Not implemented"); + assumingUnreachable(); + } else { + InputIds = *std::get_if(&Kwargs.find("input_ids")->second); + PixelValues = *std::get_if(&Kwargs.find("pixel_values")->second); + Mask = *std::get_if(&Kwargs.find("mask")->second); + } + // Generate_state + // Initialize cache + std::vector> Cache; + int MaxTokens = 256; + float Temperature = 0.0f; + std::optional RepetitionPenalty = std::nullopt; + size_t RepetitionContextSize = 20; + float TopP = 1.0f; + std::map LogitBias = {}; + auto LanguageModel = std::dynamic_pointer_cast( + this->Submodules["language_model"]); + + auto Sample = [&](mx::array Logits) -> std::tuple { + if (!LogitBias.empty()) { + for (const auto &[Index, Value] : LogitBias) { + Logits = + scatter_add_axis(Logits, mx::array({Index}), mx::array({Value}), 1); + } + } + + mx::array LogProbs = Logits - mx::logsumexp(Logits, -1); + mx::array Token = mx::array({}); + if (Temperature == 0.0f) { + Token = mx::argmax(Logits, -1); + } else { + if (TopP > 0 and TopP < 1.0) { + Token = topPSampling(Logits, TopP, Temperature); + } else { + Token = mx::random::categorical(Logits / Temperature); + } + } + + return {Token, LogProbs}; + }; + + if (RepetitionPenalty.has_value() && RepetitionPenalty < 0) { + spdlog::error("Repetition penalty must be greater than 0"); + assumingUnreachable(); + } + + auto MakeCache = LanguageModel->makeCache(); + Cache.insert(Cache.begin(), MakeCache.begin(), MakeCache.end()); + + // Initialize repetition context + auto FlatternInputIdsShape = reshape(InputIds, {-1}); + mx::async_eval(FlatternInputIdsShape); + std::vector RepetitionContext(FlatternInputIdsShape.data(), + FlatternInputIdsShape.data() + + InputIds.size()); + if (RepetitionContext.size() > RepetitionContextSize) { + RepetitionContext.erase(RepetitionContext.begin(), + RepetitionContext.end() - RepetitionContextSize); + } + + auto Step = [&](mx::array Y) -> std::tuple { + std::vector NewShape = Y.shape(); + NewShape.insert(NewShape.begin(), 1); + // TODO: handle decoder_input_ids + auto Outputs = std::dynamic_pointer_cast( + this->Submodules["language_model"]) + ->forward(reshape(Y, NewShape), Cache); + mx::array Logits = std::get<0>(Outputs); + Logits = take(Logits, Logits.shape()[1] - 1, 1); + mx::array LogProbs = mx::array({}); + if (RepetitionPenalty.has_value()) { + if (RepetitionContext.size() > 0) { + auto Indices = mx::array(RepetitionContext.data(), + {static_cast(RepetitionContext.size())}); + auto SelectedLogits = take(Logits, Indices, 1); + SelectedLogits = where(SelectedLogits < 0, + SelectedLogits * RepetitionPenalty.value(), + SelectedLogits / RepetitionPenalty.value()); + put_along_axis(Logits, Indices, SelectedLogits, 1); + } + std::tie(Y, LogProbs) = Sample(Logits); + RepetitionContext.emplace_back(Y.item()); + } else { + std::tie(Y, LogProbs) = Sample(Logits); + } + if (RepetitionContext.size() > RepetitionContextSize) { + RepetitionContext.erase(RepetitionContext.begin(), + RepetitionContext.end() - RepetitionContextSize); + } + return {Y, squeeze(LogProbs, 0)}; + }; + + // Perform the first step + auto Outputs = this->forward(InputIds, PixelValues, Mask, Cache); + mx::array Logits = std::get<0>(Outputs); + Logits = take(Logits, Logits.shape()[1] - 1, 1); + auto [Y, LogProbs] = Sample(Logits); + mx::async_eval(Y); + // TODO: handle cross_attention_states, encoder_outputs + // End generate_state + + std::optional Result; + GenerationResult LastResponse; + // auto Tic = std::chrono::system_clock::now().time_since_epoch(); + std::vector TokenList; + int N = 0; + while (true) { + // for statistic + // int PromptTPS; + // if (N == 0) { + // PromptTPS = + // std::chrono::duration_cast(Tic).count(); Tic = + // std::chrono::system_clock::now().time_since_epoch(); + // } + // TODO: if token = eos_token_id break + auto Response = GenerationResult(); + N++; + if (N >= MaxTokens) { + break; + } + auto [NextY, NextLogProbs] = Step(Y); + mx::async_eval(NextY); + // TODO: handle decoder_input_ids + auto Token = Y.item(); + if (Token == 1 || Token == 106) { + break; + } + TokenList.emplace_back(Token); + Y = NextY; + LogProbs = NextLogProbs; + } + return TokenList; + // end stream generate + // TODO: waiting for processor + if (Verbose) { + spdlog::info("\n=========="sv); + if (Text.empty()) { + spdlog::info("No text generated for this prompt"sv); + return {}; + // return Text; + } + + spdlog::info("Prompt: {} tokends, {} tokens-per-sec"sv, + LastResponse.PromptTokens, LastResponse.PromptTps); + spdlog::info("Generation: {} tokens {} tokens-per-sec"sv, + LastResponse.GenerationTokens, LastResponse.GenerationTps); + spdlog::info("Peak memory: {} GB"sv); + } + + // return Text; +} +} // namespace vlm + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/vlm_base.h b/plugins/wasi_nn/MLX/model/vlm_base.h new file mode 100644 index 00000000..45fc835f --- /dev/null +++ b/plugins/wasi_nn/MLX/model/vlm_base.h @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once +#include "mlx/base.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +namespace vlm { +class BaseCache { +public: + int Offset = 0; + + virtual ~BaseCache() = default; + + virtual std::tuple + updateAndFetch(const mx::array &NewKeys, const mx::array &NewValues) = 0; + + virtual std::vector getState() const; + virtual void setState(const std::vector &State); + + virtual std::string getMetaState() const; + virtual void setMetaState(const std::string &Value); + + virtual bool isTrimmable() const; + virtual int trim(int N); + virtual std::string getType() const { return "BaseCache"; } +}; + +class KVCache : public BaseCache { +public: + int NKVHeads; + int KHeadDim; + int VHeadDim; + mx::array Keys = mx::array({}); + mx::array Values = mx::array({}); + int Step; + + KVCache(int HeadDim, int NKVHeads, int Step = 256); + KVCache(std::pair HeadDims, int NKVHeads, int Step = 256); + virtual std::tuple + updateAndFetch(const mx::array &NewKeys, const mx::array &NewValues) override; + + std::tuple fetch() const; + + void update(const mx::array &NewKeys, const mx::array &NewValues); + + std::vector getState() const override; + void setState(const std::vector &State) override; + + bool isTrimmable() const override; + int trim(int N) override; + std::string getType() const override { return "KVCache"; } +}; + +class RotatingKVCache : public KVCache { +public: + int Keep; + int MaxSize; + int Idx; + + RotatingKVCache(int MaxSize = -1, int Keep = 0, int StepSize = 256); + + std::tuple + updateAndFetch(const mx::array &NewKeys, const mx::array &NewValues) override; + + std::tuple updateInPlace(const mx::array &NewKeys, + const mx::array &NewValues); + + std::tuple updateConcat(const mx::array &NewKeys, + const mx::array &NewValues); + + mx::array trim(int TrimSize, const mx::array &V, + std::optional Append = std::nullopt); + + mx::array temporalOrder(const mx::array &V); + + std::string getMetaState() const override; + void setMetaState(const std::string &Value) override; + + bool isTrimmable() const override; + int trim(int N) override; + std::string getType() const override { return "RotatingKVCache"; } +}; + +class Module : public mlx::core::nn::Module { +public: + virtual std::tuple> forward( + const mx::array &InputIds, const mx::array &PixelValues, + const mx::array &Mask, + const std::optional>> &Cache = + std::nullopt) = 0; + struct GenerationResult { + std::string Text; + int Token; + std::vector LogProbs; + int PromptTokens; + int GenerationTokens; + float PromptTps; + float GenerationTps; + float PeakMemory; + }; + + // Add this struct to hold generation state + struct StreamGenerationState { + void *Model; + void *Processor; + mx::array PromptTokens; + mx::array InputIds; + mx::array PixelValues; + mx::array Mask; + mx::array CurrentToken; + std::vector CurrentLogProbs; + int TokenCount; + double StartTime; + double PromptTime; + float PromptTps; + bool IsComplete; + std::map Kwargs; + + // Generation parameters + int MaxTokens = 256; + float Temperature = 0.0; + std::optional RepetitionPenalty = std::nullopt; + std::optional RepetitionContextSize = 20; + float TopP = 1.0; + std::map LogitBias; + + // State for generate_step + std::vector RepetitionContext; + std::vector Cache; // Will hold appropriate cache objects + mx::array CrossAttentionStates; + mx::array EncoderOutputs; + }; + + std::vector generate( + const std::string &Prompt = {}, + std::optional Image = std::nullopt, bool Verbose = false, + std::map> + Kwargs = {}); +}; +class LanguageModel : public mlx::core::nn::Module { +public: + virtual int headDim() const = 0; + virtual int nKvHeads() const = 0; + virtual int layers() const = 0; + virtual std::vector> makeCache() { + std::vector> Cache; + int HeadDim = headDim(); + auto KVHeads = nKvHeads(); + for (int I = 0; I < layers(); ++I) { + Cache.emplace_back(std::make_shared(HeadDim, KVHeads)); + } + return Cache; + } + virtual std::tuple> forward( + const mx::array &Inputs, + const std::optional>> &Cache = + std::nullopt) = 0; +}; + +std::optional createAttentionMask( + mx::array H, + std::optional>> = std::nullopt); + +mx::array createAdditiveCausalMask(int N, int Offset = 0); + +} // namespace vlm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/vlm_sampling.cpp b/plugins/wasi_nn/MLX/model/vlm_sampling.cpp new file mode 100644 index 00000000..c4f9f41c --- /dev/null +++ b/plugins/wasi_nn/MLX/model/vlm_sampling.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "model/vlm_sampling.h" +namespace WasmEdge::Host::WASINN::MLX { +namespace vlm { + +mx::array topPSampling(const mx::array &Logits, float TopP, float Temperature) { + + mx::array WorkingLogits = Logits; + + if (WorkingLogits.dtype() == mx::bfloat16) { + WorkingLogits = astype(WorkingLogits, mx::float32); + } + + mx::array Probs = mx::softmax(WorkingLogits / Temperature, -1); + mx::array SortedIndices = mx::argsort(Probs, -1); + mx::array SqueezedIndices = mx::squeeze(SortedIndices, 0); + mx::array SortedProbs = mx::take(Probs, SqueezedIndices, -1); + mx::array CumulativeProbs = mx::cumsum(SortedProbs, -1); + mx::array TopProbs = mx::where(CumulativeProbs > 1.0f - TopP, SortedProbs, + mx::zeros_like(SortedProbs)); + mx::array SortedToken = mx::random::categorical(mx::log(TopProbs)); + mx::array Token = mx::take(SqueezedIndices, SortedToken); + + return Token; +} + +} // namespace vlm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/vlm_sampling.h b/plugins/wasi_nn/MLX/model/vlm_sampling.h new file mode 100644 index 00000000..a5fb1592 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/vlm_sampling.h @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once +#include "mlx/base.h" +#include +namespace WasmEdge::Host::WASINN::MLX { +namespace vlm { + +mx::array topPSampling(const mx::array &Logits, float TopP, float Temperature); + +} // namespace vlm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/decoding.cpp b/plugins/wasi_nn/MLX/model/whisper/decoding.cpp new file mode 100644 index 00000000..a387f1ba --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/decoding.cpp @@ -0,0 +1,884 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#include "decoding.h" +#include "mlx/base.h" +#include "tokenizer.h" +#include "whisper.h" +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +// Utility function implementation +float compressionRatio(const std::string &Text) { + if (Text.empty()) + return 1.0f; + + std::vector Input(Text.begin(), Text.end()); + uLongf CompressedSize = compressBound(Input.size()); + std::vector Compressed(CompressedSize); + + int Result = + compress(Compressed.data(), &CompressedSize, Input.data(), Input.size()); + + if (Result != Z_OK) + return 1.0f; + + return static_cast(Input.size()) / static_cast(CompressedSize); +} + +// Language detection implementation +std::pair>> +detectLanguage(std::shared_ptr Model, const mx::array &Mel, + std::shared_ptr Tokenizer) { + + if (!Tokenizer) { + Tokenizer = getTokenizer(Model->isMultilingual(), Model->numLanguages()); + } + + if (!Tokenizer->Language || + std::find(Tokenizer->SotSequence.begin(), Tokenizer->SotSequence.end(), + Tokenizer->languageToken()) == Tokenizer->SotSequence.end()) { + throw std::runtime_error( + "This model doesn't have language tokens so it can't perform lang id"); + } + bool Single = Mel.ndim() == 2; + mx::array MelArray = Single ? mx::expand_dims(Mel, 0) : Mel; + // Skip encoder forward pass if already-encoded audio features were given + if (MelArray.shape(-2) != Model->Dims.NAudioCtx || + MelArray.shape(-1) != Model->Dims.NAudioState) { + MelArray = Model->embedAudio(MelArray); + } + // Forward pass using a single token, start of transcript + int NAudio = MelArray.shape(0); + std::vector> TokensVec(NAudio, {Tokenizer->getSot()}); + mx::array Tokens = mx::array(TokensVec[0].data(), {NAudio, 1}, mx::int32); + + mx::array Logits = Model->logits(Tokens, MelArray); + Logits = mx::take(Logits, mx::array({0}), 1); // [:, 0] + // Collect detected languages; suppress all non-language tokens + mx::array MaskArray = mx::full( + {Logits.shape(-1)}, -std::numeric_limits::infinity(), mx::float32); + auto LangTokens = Tokenizer->getAllLanguageTokens(); + mx::array LangTokensArray = mx::array( + LangTokens.data(), {static_cast(LangTokens.size())}, mx::int32); + MaskArray = + mx::scatter(MaskArray, LangTokensArray, + mx::zeros({static_cast(LangTokens.size()), 1}), 0); + + Logits = Logits + MaskArray; + mx::array LanguageTokens = mx::argmax(Logits, -1); + mx::array LanguageTokenProbs = mx::softmax(Logits, -1); + LanguageTokenProbs = mx::take(LanguageTokenProbs, 0, 0); + + std::vector> LanguageProbs; + auto LangCodes = Tokenizer->getAllLanguageCodes(); + + for (int I = 0; I < NAudio; ++I) { + std::map Probs; + for (size_t J = 0; J < LangTokens.size() && J < LangCodes.size(); ++J) { + int TokenId = LangTokens[J]; + std::string LangCode = LangCodes[J]; + + mx::array ProbArray = + mx::take(mx::take(LanguageTokenProbs, mx::array({0}), 0), + mx::array({TokenId}), -1); + ProbArray = mx::squeeze(ProbArray); + + float Prob = ProbArray.item(); + Probs[LangCode] = Prob; + } + LanguageProbs.push_back(Probs); + } + + if (Single) { + LanguageTokens = mx::take(LanguageTokens, mx::array({0}), 0); + LanguageProbs = {LanguageProbs[0]}; + } + + return {LanguageTokens, LanguageProbs}; +} + +// Inference class implementation +Inference::Inference(std::shared_ptr Model) : Model(Model) { reset(); } + +mx::array Inference::logits(const mx::array &Tokens, + const mx::array &AudioFeatures) { + auto [LogitsOutput, NewKvCache, _] = + std::dynamic_pointer_cast(Model->Submodules.at("decoder")) + ->forward(Tokens, AudioFeatures, KvCache); + KvCache = NewKvCache; + return mx::astype(LogitsOutput, mx::float32); +} + +void Inference::rearrangeKvCache(const std::vector &SourceIndices) { + // TODO: Implement KV cache rearrangement for beam search + assumingUnreachable(); +} + +void Inference::reset() { KvCache = std::nullopt; } + +// GreedyDecoder implementation +GreedyDecoder::GreedyDecoder(float Temperature, int Eot) + : Temperature(Temperature), Eot(Eot) {} + +void GreedyDecoder::reset() { + // Nothing to reset for greedy decoder +} + +std::tuple +GreedyDecoder::update(const mx::array &Tokens, const mx::array &Logits, + const mx::array &SumLogprobs) { + + // Sample next tokens + mx::array NextTokens = mx::array({}); + if (Temperature == 0.0f) { + NextTokens = mx::argmax(Logits, -1); + } else { + NextTokens = mx::random::categorical(Logits / Temperature); + } + + // Compute logprobs + mx::array Logprobs = Logits - mx::logsumexp(Logits, -1, false); + mx::array CurrentLogprobs = + mx::take(Logprobs, mx::arange(Logprobs.shape(0)), 0); + CurrentLogprobs = take(CurrentLogprobs, NextTokens, 1); + mx::array NewSumLogprobs = + SumLogprobs + + CurrentLogprobs * (take(Tokens, Tokens.shape(1) - 1, 1) != Eot); + + // Extend tokens + mx::array NewTokens = + mx::concatenate({Tokens, mx::expand_dims(NextTokens, -1)}, -1); + + // Check if all sequences are complete + bool Completed = mx::all(mx::equal(mx::take(NewTokens, mx::array({-1}), 1), + mx::array({Eot}))) + .item(); + + return {NewTokens, Completed, NewSumLogprobs}; +} + +std::pair +GreedyDecoder::finalize(const mx::array &Tokens, const mx::array &SumLogprobs) { + + // Make sure each sequence has at least one EOT token at the end + std::vector> PadWidths = {{0, 0}, {0, 0}, {0, 1}}; + mx::array PaddedTokens = + mx::pad(Tokens, PadWidths, mx::array(Eot, mx::int32)); + + return {PaddedTokens, SumLogprobs}; +} + +// SuppressBlank implementation +SuppressBlank::SuppressBlank(std::shared_ptr Tokenizer, + int SampleBegin, int NVocab) + : SampleBegin(SampleBegin), Mask(mx::zeros({NVocab}, mx::float32)) { + Name = "SuppressBlank"; + std::vector MaskVec(NVocab, 0.0f); + + // Suppress space and EOT tokens + auto SpaceTokens = Tokenizer->encode(" "); + for (int Token : SpaceTokens) { + if (Token >= 0 && Token < NVocab) { + MaskVec[Token] = -std::numeric_limits::infinity(); + } + } + + int EotToken = Tokenizer->getEot(); + if (EotToken < NVocab) { + MaskVec[EotToken] = -std::numeric_limits::infinity(); + } + + Mask = mx::array(MaskVec.data(), {NVocab}, mx::float32); +} + +mx::array SuppressBlank::apply(const mx::array &Logits, + const mx::array &Tokens) { + if (Tokens.shape(1) == SampleBegin) { + return Logits + Mask; + } + return Logits; +} + +// SuppressTokens implementation +SuppressTokens::SuppressTokens(const std::vector &SuppressTokens, + int NVocab) + : Mask(mx::zeros({NVocab}, mx::float32)) { + Name = "SuppressTokens"; + std::vector MaskVec(NVocab, 0.0f); + for (int Token : SuppressTokens) { + MaskVec[Token] = -std::numeric_limits::infinity(); + } + + Mask = mx::array(MaskVec.data(), {NVocab}, mx::float32); +} + +mx::array SuppressTokens::apply(const mx::array &Logits, + const mx::array &Tokens) { + return Logits + Mask; +} + +// ApplyTimestampRules implementation +ApplyTimestampRules::ApplyTimestampRules( + std::shared_ptr Tokenizer, int SampleBegin, + std::optional MaxInitialTimestampIndex) + : Tokenizer(Tokenizer), SampleBegin(SampleBegin), + MaxInitialTimestampIndex(MaxInitialTimestampIndex) { + Name = "ApplyTimestampRules"; +} + +mx::array ApplyTimestampRules::apply(const mx::array &Logits, + const mx::array &Tokens) { + auto LogitsShape = Logits.shape(); + std::vector MaskVec(Logits.size(), 0.0f); + + // Suppress <|notimestamps|> which is handled by without_timestamps + if (Tokenizer->getNoTimestamps()) { + for (int I = 0; I < LogitsShape[0]; ++I) { + MaskVec[I * LogitsShape[1] + Tokenizer->getNoTimestamps()] = + -std::numeric_limits::infinity(); + } + } + + mx::eval(Tokens); + std::vector> TokensList(Tokens.shape(0)); + for (int K = 0; K < Tokens.shape(0); ++K) { + TokensList[K].resize(Tokens.shape(1)); + for (int J = 0; J < Tokens.shape(1); ++J) { + mx::array TokenVal = mx::take(mx::take(Tokens, K, 0), J, 0); + TokensList[K][J] = TokenVal.item(); + } + } + + // Timestamps have to appear in pairs, except directly before EOT; mask logits + // accordingly + for (int K = 0; K < static_cast(TokensList.size()); ++K) { + // seq = tokens[k][self.sample_begin :] + std::vector Seq(TokensList[K].begin() + SampleBegin, + TokensList[K].end()); + + bool LastWasTimestamp = + Seq.size() >= 1 && + Seq[Seq.size() - 1] >= Tokenizer->getTimestampBegin(); + bool PenultimateWasTimestamp = + Seq.size() < 2 || Seq[Seq.size() - 2] >= Tokenizer->getTimestampBegin(); + + if (LastWasTimestamp) { + if (PenultimateWasTimestamp) { + // Has to be non-timestamp + for (int I = Tokenizer->getTimestampBegin(); I < LogitsShape[1]; ++I) { + MaskVec[K * LogitsShape[1] + I] = + -std::numeric_limits::infinity(); + } + } else { + // Cannot be normal text tokens + for (int I = 0; I < Tokenizer->getEot(); ++I) { + MaskVec[K * LogitsShape[1] + I] = + -std::numeric_limits::infinity(); + } + } + } + + // Find timestamps in sequence and enforce monotonicity + std::vector Timestamps; + for (size_t I = 0; I < Seq.size(); ++I) { + if (Seq[I] > Tokenizer->getTimestampBegin()) { + Timestamps.push_back(Seq[I]); + } + } + + if (!Timestamps.empty()) { + // Timestamps shouldn't decrease; forbid timestamp tokens smaller than the + // last Also force each segment to have a nonzero length, to prevent + // infinite looping + int LastTimestamp = Timestamps.back(); + if (LastTimestamp == 0 || PenultimateWasTimestamp) { + LastTimestamp += 1; + } + for (int I = Tokenizer->getTimestampBegin(); I < LastTimestamp; ++I) { + MaskVec[K * LogitsShape[1] + I] = + -std::numeric_limits::infinity(); + } + } + } + + if (static_cast(TokensList[0].size()) == SampleBegin) { + // Suppress generating non-timestamp tokens at the beginning + for (int I = 0; I < LogitsShape[0]; ++I) { + for (int J = 0; J < Tokenizer->getTimestampBegin(); ++J) { + MaskVec[I * LogitsShape[1] + J] = + -std::numeric_limits::infinity(); + } + } + + // Apply the `max_initial_timestamp` option + if (MaxInitialTimestampIndex) { + int LastAllowed = + Tokenizer->getTimestampBegin() + *MaxInitialTimestampIndex; + for (int I = 0; I < LogitsShape[0]; ++I) { + for (int J = LastAllowed + 1; J < LogitsShape[1]; ++J) { + MaskVec[I * LogitsShape[1] + J] = + -std::numeric_limits::infinity(); + } + } + } + } + + // If the sum of probabilities over timestamps is above any other token, + // sample the timestamp. + mx::array MaskArray = mx::array(MaskVec.data(), LogitsShape, mx::float32); + mx::array Logprobs = Logits - mx::logsumexp(Logits, -1, true); + + // Calculate timestamp logprob: sum of probabilities for all timestamp tokens + mx::array TimestampLogprob = + mx::logsumexp(mx::slice(Logprobs, {0, Tokenizer->getTimestampBegin()}, + {LogitsShape[0], LogitsShape[1]}), + -1, true); + + // Calculate max text token logprob: max probability among non-timestamp + // tokens + mx::array MaxTextTokenLogprob = + mx::max(mx::slice(Logprobs, {0, 0}, + {LogitsShape[0], Tokenizer->getTimestampBegin()}), + -1, true); + + // Where timestamp probability > max text probability, suppress text tokens + mx::array TimestampCondition = + mx::greater(TimestampLogprob, MaxTextTokenLogprob); + + for (int I = 0; I < LogitsShape[0]; ++I) { + bool ShouldSuppressText = mx::take(TimestampCondition, I, 0).item(); + if (ShouldSuppressText) { + for (int J = 0; J < Tokenizer->getTimestampBegin(); ++J) { + MaskVec[I * LogitsShape[1] + J] = + -std::numeric_limits::infinity(); + } + } + } + + MaskArray = mx::array(MaskVec.data(), LogitsShape, mx::float32); + return Logits + MaskArray; +} + +// MaximumLikelihoodRanker implementation +MaximumLikelihoodRanker::MaximumLikelihoodRanker( + std::optional LengthPenalty) + : LengthPenalty(LengthPenalty) {} + +std::vector MaximumLikelihoodRanker::rank( + const std::vector>> &Tokens, + const std::vector> &SumLogprobs) { + + std::vector Selected; + + for (size_t I = 0; I < Tokens.size(); ++I) { + std::vector Scores; + + for (size_t J = 0; J < Tokens[I].size(); ++J) { + int Length = Tokens[I][J].size(); + float Logprob = SumLogprobs[I][J]; + + float Penalty; + if (LengthPenalty) { + Penalty = std::pow(Length, *LengthPenalty); + } else { + Penalty = Length; + } + + Scores.push_back(Logprob / Penalty); + } + + auto MaxIterator = std::max_element(Scores.begin(), Scores.end()); + Selected.push_back(std::distance(Scores.begin(), MaxIterator)); + } + + return Selected; +} + +// DecodingTask implementation - Constructor and helper methods +DecodingTask::DecodingTask(std::shared_ptr Model, + const DecodingOptions &Options) + : Model(Model), Options(verifyOptions(Options)) { + + std::string Language = Options.Language.value_or("en"); + Tokenizer = whisper::getTokenizer( + Model->isMultilingual(), Model->numLanguages(), Language, Options.Task); + + NGroup = Options.BeamSize.value_or(Options.BestOf.value_or(1)); + NCtx = Model->Dims.NTextCtx; + SampleLen = Options.SampleLen.value_or(NCtx / 2); + + // Handle SOT sequence with without_timestamps logic + SotSequence = Tokenizer->SotSequence; + if (Options.WithoutTimestamps) { + SotSequence = Tokenizer->getSotSequenceIncludingNotimestamps(); + } + + InitialTokens = getInitialTokens(); + SampleBegin = InitialTokens.size(); + + // Find SOT index + auto SotToken = Tokenizer->getSot(); + auto Iterator = + std::find(InitialTokens.begin(), InitialTokens.end(), SotToken); + SotIndex = std::distance(InitialTokens.begin(), Iterator); + + // Initialize components + Inference = std::make_unique(Model); + SequenceRanker = + std::make_unique(Options.LengthPenalty); + + if (Options.BeamSize && *Options.BeamSize > 1) { + throw std::runtime_error("Beam search decoder is not yet implemented"); + } + Decoder = + std::make_unique(Options.Temperature, Tokenizer->getEot()); + + LogitFilters.clear(); + + if (Options.SuppressBlank) { + LogitFilters.push_back(std::make_unique( + Tokenizer, SampleBegin, Model->Dims.NVocab)); + } + + if (Options.SuppressTokens) { + auto SuppressTokens = getSuppressTokens(); + LogitFilters.push_back(std::make_unique( + SuppressTokens, Model->Dims.NVocab)); + } + + if (!Options.WithoutTimestamps) { + std::optional MaxInitialTimestampIndex; + if (Options.MaxInitialTimestamp) { + float Precision = + 30.0f / Model->Dims.NAudioCtx; // CHUNK_LENGTH / n_audio_ctx + MaxInitialTimestampIndex = static_cast( + std::round(*Options.MaxInitialTimestamp / Precision)); + } + LogitFilters.push_back(std::make_unique( + Tokenizer, SampleBegin, MaxInitialTimestampIndex)); + } +} + +DecodingOptions +DecodingTask::verifyOptions(const DecodingOptions &InputOptions) { + DecodingOptions Result = InputOptions; + + // Check beam_size and best_of conflicts + if (Result.BeamSize && Result.BestOf) { + throw std::runtime_error("beam_size and best_of can't be given together"); + } + + // Check temperature = 0 with best_of + if (Result.Temperature == 0.0f && Result.BestOf) { + throw std::runtime_error( + "best_of with greedy sampling (T=0) is not compatible"); + } + + // Check patience requires beam_size + if (Result.Patience && !Result.BeamSize) { + throw std::runtime_error("patience requires beam_size to be given"); + } + + // Check length_penalty range + if (Result.LengthPenalty && + (*Result.LengthPenalty < 0.0f || *Result.LengthPenalty > 1.0f)) { + throw std::runtime_error( + "length_penalty (alpha) should be a value between 0 and 1"); + } + + return Result; +} + +std::vector DecodingTask::getInitialTokens() { + std::vector Tokens = SotSequence; + + if (Options.Prefix) { + std::vector PrefixTokens; + if (std::holds_alternative(*Options.Prefix)) { + std::string PrefixStr = std::get(*Options.Prefix); + PrefixTokens = Tokenizer->encode(" " + PrefixStr); + } else { + PrefixTokens = std::get>(*Options.Prefix); + } + + if (SampleLen > 0) { + int MaxPrefixLen = NCtx / 2 - SampleLen; + if (static_cast(PrefixTokens.size()) > MaxPrefixLen) { + PrefixTokens = std::vector(PrefixTokens.end() - MaxPrefixLen, + PrefixTokens.end()); + } + } + + Tokens.insert(Tokens.end(), PrefixTokens.begin(), PrefixTokens.end()); + } + + if (Options.Prompt) { + std::vector PromptTokens; + if (std::holds_alternative(*Options.Prompt)) { + std::string PromptStr = std::get(*Options.Prompt); + PromptTokens = Tokenizer->encode(" " + PromptStr); + } else { + PromptTokens = std::get>(*Options.Prompt); + } + + int MaxPromptLen = NCtx / 2 - 1; + if (static_cast(PromptTokens.size()) > MaxPromptLen) { + PromptTokens = std::vector(PromptTokens.end() - MaxPromptLen, + PromptTokens.end()); + } + + std::vector NewTokens; + NewTokens.push_back(Tokenizer->getSotPrev()); + NewTokens.insert(NewTokens.end(), PromptTokens.begin(), PromptTokens.end()); + NewTokens.insert(NewTokens.end(), Tokens.begin(), Tokens.end()); + Tokens = NewTokens; + } + + return Tokens; +} + +std::vector DecodingTask::getSuppressTokens() { + std::vector SuppressTokens; + + if (Options.SuppressTokens) { + if (std::holds_alternative(*Options.SuppressTokens)) { + std::string TokensString = std::get(*Options.SuppressTokens); + std::istringstream Iss(TokensString); + std::string TokenString; + while (std::getline(Iss, TokenString, ',')) { + if (!TokenString.empty()) { + SuppressTokens.push_back(std::stoi(TokenString)); + } + } + } else { + SuppressTokens = std::get>(*Options.SuppressTokens); + } + } + + auto Iterator = std::find(SuppressTokens.begin(), SuppressTokens.end(), -1); + if (Iterator != SuppressTokens.end()) { + SuppressTokens.erase(std::remove_if(SuppressTokens.begin(), + SuppressTokens.end(), + [](int Token) { return Token < 0; }), + SuppressTokens.end()); + auto NonSpeechTokens = Tokenizer->getNonSpeechTokens(); + SuppressTokens.insert(SuppressTokens.end(), NonSpeechTokens.begin(), + NonSpeechTokens.end()); + } else if (!Options.SuppressTokens || SuppressTokens.empty()) { + SuppressTokens.clear(); + } else { + assumingUnreachable(); + } + + SuppressTokens.push_back(Tokenizer->getTranscribe()); + SuppressTokens.push_back(Tokenizer->getTranslate()); + SuppressTokens.push_back(Tokenizer->getSot()); + SuppressTokens.push_back(Tokenizer->getSotPrev()); + SuppressTokens.push_back(Tokenizer->getSotLm()); + + SuppressTokens.push_back(Tokenizer->getNoSpeech()); + + std::sort(SuppressTokens.begin(), SuppressTokens.end()); + SuppressTokens.erase( + std::unique(SuppressTokens.begin(), SuppressTokens.end()), + SuppressTokens.end()); + + return SuppressTokens; +} + +mx::array DecodingTask::getAudioFeatures(const mx::array &Mel) { + bool Single = Mel.ndim() == 2; + mx::array MelArray = Single ? mx::expand_dims(Mel, 0) : Mel; + + mx::array AudioFeatures = MelArray; + + // Skip encoder forward pass if already-encoded audio features were given + if (AudioFeatures.shape(-2) != Model->Dims.NAudioCtx || + AudioFeatures.shape(-1) != Model->Dims.NAudioState) { + AudioFeatures = Model->embedAudio(AudioFeatures); + } + + return AudioFeatures; +} + +std::pair, + std::optional>>> +DecodingTask::detectLanguage(const mx::array &AudioFeatures, + mx::array &Tokens) { + + std::vector Languages(AudioFeatures.shape(0), + Options.Language.value_or("en")); + std::optional>> LangProbs; + + if (!Options.Language || Options.Task == "lang_id") { + // Call the global detectLanguage function + auto [DetectedLanguageTokens, Probabilities] = + whisper::detectLanguage(Model, AudioFeatures, Tokenizer); + + Languages.clear(); + for (const auto &ProbsMap : Probabilities) { + auto MaxIterator = std::max_element( + ProbsMap.begin(), ProbsMap.end(), + [](const auto &A, const auto &B) { return A.second < B.second; }); + Languages.push_back(MaxIterator->first); + } + + LangProbs = Probabilities; + + if (!Options.Language) { + std::vector Start = {0, SotIndex + 1}; + std::vector End = {Tokens.shape()[0], SotIndex + 2}; + + Tokens = slice_update(Tokens, DetectedLanguageTokens, Start, End); + } + } + + return {Languages, LangProbs}; +} + +std::tuple +DecodingTask::mainLoop(const mx::array &AudioFeatures, + const mx::array &Tokens) { + + int NBatch = Tokens.shape(0); + mx::array CurrentTokens = Tokens; + mx::array SumLogprobs = mx::zeros({NBatch}, mx::float32); + bool Completed = false; + + auto StepFunction = [&](const mx::array &Inputs, const mx::array &AudioFeats, + const mx::array &TokSeq, const mx::array &SumLogp) + -> std::tuple { + mx::array PreLogits = Inference->logits(Inputs, AudioFeats); + mx::array Logits = take(PreLogits, PreLogits.shape(1) - 1, 1); + for (const auto &Filter : LogitFilters) { + Logits = Filter->apply(Logits, TokSeq); + } + auto [NextTokens, CompletedFlag, NextSumLogprobs] = + Decoder->update(TokSeq, Logits, SumLogp); + return std::make_tuple(NextTokens, CompletedFlag, NextSumLogprobs, + PreLogits); + }; + + auto [NextTokens, CompletedFlag, NextSumLogprobs, PreLogits] = + StepFunction(CurrentTokens, AudioFeatures, CurrentTokens, SumLogprobs); + + CurrentTokens = NextTokens; + SumLogprobs = NextSumLogprobs; + Completed = CompletedFlag; + + mx::array NoSpeechProbs = mx::zeros({NBatch}, mx::float32); + if (Tokenizer->getNoSpeech() != -1) { + auto ProbsAtSot = mx::softmax(mx::take(PreLogits, SotIndex, 1), -1); + NoSpeechProbs = mx::take(ProbsAtSot, Tokenizer->getNoSpeech(), 1); + } else { + NoSpeechProbs = mx::full({NBatch}, std::numeric_limits::quiet_NaN(), + mx::float32); + } + + mx::eval(CurrentTokens, SumLogprobs, NoSpeechProbs); + for (int I = 1; I < SampleLen; ++I) { + mx::array Inputs = + take(CurrentTokens, mx::array({CurrentTokens.shape(1) - 1}), 1); + + if (CurrentTokens.shape(-1) > NCtx) { + break; + } + auto [NextToks, NextCompleted, NextSumLogp, _] = + StepFunction(Inputs, AudioFeatures, CurrentTokens, SumLogprobs); + mx::eval(NextToks, NextSumLogp); + + if (Completed) { + break; + } + + CurrentTokens = NextToks; + Completed = NextCompleted; + SumLogprobs = NextSumLogp; + } + + return std::make_tuple(CurrentTokens, SumLogprobs, NoSpeechProbs); +} + +std::vector DecodingTask::run(const mx::array &Mel) { + Inference->reset(); + Decoder->reset(); + int NAudio = Mel.shape(0); + + mx::array AudioFeatures = getAudioFeatures(Mel); + + mx::array Tokens = + mx::array(InitialTokens.data(), {static_cast(InitialTokens.size())}, + mx::int32); + Tokens = mx::broadcast_to(Tokens, + {NAudio, static_cast(InitialTokens.size())}); + auto [Languages, LangProbs] = detectLanguage(AudioFeatures, Tokens); + + if (Options.Task == "lang_id") { + std::vector Results; + for (int I = 0; I < NAudio; ++I) { + DecodingResult Result; + Result.AudioFeatures = mx::take(AudioFeatures, mx::array({I}), 0); + Result.Language = Languages[I]; + if (LangProbs) { + Result.LanguageProbs = (*LangProbs)[I]; + } + Results.push_back(Result); + } + return Results; + } + + if (NGroup > 1) { + // tokens = tokens[:, None, :] + Tokens = mx::expand_dims(Tokens, 1); + + // tokens = mx.broadcast_to(tokens, [n_audio, self.n_group, + // len(self.initial_tokens)]) + std::vector NewShape = {NAudio, NGroup, + static_cast(InitialTokens.size())}; + Tokens = mx::broadcast_to(Tokens, NewShape); + + // tokens = tokens.reshape(n_audio * self.n_group, len(self.initial_tokens)) + Tokens = mx::reshape( + Tokens, {NAudio * NGroup, static_cast(InitialTokens.size())}); + } + // Call the main sampling loop + auto [TokensResult, SumLogprobs, NoSpeechProbs] = + mainLoop(AudioFeatures, Tokens); + + // Reshape the tensors to have (n_audio, n_group) as the first two dimensions + AudioFeatures = + mx::take(AudioFeatures, mx::arange(0, AudioFeatures.shape(0), NGroup), 0); + NoSpeechProbs = + mx::take(NoSpeechProbs, mx::arange(0, NoSpeechProbs.shape(0), NGroup), 0); + + // Ensure shapes are consistent + if (AudioFeatures.shape(0) != NoSpeechProbs.shape(0) || + AudioFeatures.shape(0) != NAudio) { + throw std::runtime_error( + "Inconsistent audio features and no_speech_probs shapes"); + } + + TokensResult = mx::reshape(TokensResult, {NAudio, NGroup, -1}); + SumLogprobs = mx::reshape(SumLogprobs, {NAudio, NGroup}); + // Get the final candidates for each group, and slice between the first + // sampled token and EOT + auto [FinalizedTokens, FinalizedLogprobs] = + Decoder->finalize(TokensResult, SumLogprobs); + + // tokens[..., self.sample_begin:] + std::vector SliceStart(FinalizedTokens.ndim(), 0); + std::vector SliceEnd = FinalizedTokens.shape(); + SliceStart[FinalizedTokens.ndim() - 1] = SampleBegin; + FinalizedTokens = mx::slice(FinalizedTokens, SliceStart, SliceEnd); + + mx::eval(FinalizedTokens, FinalizedLogprobs, NoSpeechProbs); + + // Convert tokens to nested vectors and handle EOT + std::vector>> TokensList(NAudio); + std::vector> SumLogprobsList(NAudio); + + for (int I = 0; I < NAudio; ++I) { + TokensList[I].resize(NGroup); + SumLogprobsList[I].resize(NGroup); + + for (int J = 0; J < NGroup; ++J) { + // Extract tokens until EOT + for (int K = 0; K < FinalizedTokens.shape(2); ++K) { + int Dim1 = FinalizedTokens.shape(1) * FinalizedTokens.shape(2); + int Dim2 = FinalizedTokens.shape(2); + mx::array TokenArray = + mx::take(FinalizedTokens, mx::array({I * Dim1 + J * Dim2 + K})); + int Token = TokenArray.item(); + if (Token == Tokenizer->getEot()) + break; + TokensList[I][J].push_back(Token); + } + + // Extract sum_logprobs + mx::array LogprobArray = mx::take( + FinalizedLogprobs, mx::array({I * FinalizedLogprobs.shape(1) + J})); + SumLogprobsList[I][J] = LogprobArray.item(); + } + } + + std::vector Selected = SequenceRanker->rank(TokensList, SumLogprobsList); + + // Extract final results + std::vector> FinalTokens(NAudio); + std::vector Texts(NAudio); + std::vector FinalSumLogprobs(NAudio); + std::vector AvgLogprobs(NAudio); + + for (int I = 0; I < NAudio; ++I) { + int SelectedIdx = Selected[I]; + FinalTokens[I] = TokensList[I][SelectedIdx]; + Texts[I] = Tokenizer->decode(FinalTokens[I]); + + Texts[I].erase(0, Texts[I].find_first_not_of(" \t\n\r\f\v")); + Texts[I].erase(Texts[I].find_last_not_of(" \t\n\r\f\v") + 1); + + FinalSumLogprobs[I] = SumLogprobsList[I][SelectedIdx]; + AvgLogprobs[I] = FinalSumLogprobs[I] / (FinalTokens[I].size() + 1); + } + + if (Texts.size() != Languages.size() || + Languages.size() != FinalTokens.size() || + FinalTokens.size() != AvgLogprobs.size() || + static_cast(AvgLogprobs.size()) != NAudio) { + throw std::runtime_error("inconsistent result lengths"); + } + + // Create final results + std::vector Results; + for (int I = 0; I < NAudio; ++I) { + DecodingResult Result; + Result.AudioFeatures = mx::take(AudioFeatures, I, 0); + Result.Language = Languages[I]; + Result.Tokens = FinalTokens[I]; + Result.Text = Texts[I]; + Result.AvgLogprob = AvgLogprobs[I]; + + mx::array NoSpeechArray = mx::take(NoSpeechProbs, mx::array({I}), 0); + Result.NoSpeechProb = NoSpeechArray.item(); + + Result.Temperature = Options.Temperature; + Result.CompressionRatio = compressionRatio(Result.Text); + + if (LangProbs) { + Result.LanguageProbs = (*LangProbs)[I]; + } + Results.push_back(Result); + } + + return Results; +} + +// Main decode function +std::variant> +decode(std::shared_ptr Model, const mx::array &Mel, + const DecodingOptions &Options) { + auto MelArray = Mel; + if (Mel.ndim() == 2) { + auto NewShape = Mel.shape(); + NewShape.insert(NewShape.begin(), 1); + MelArray = reshape(Mel, NewShape); + } + auto Results = DecodingTask(Model, Options).run(MelArray); + + if (Results.size() == 1) { + return Results[0]; + } + return Results; +} + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/decoding.h b/plugins/wasi_nn/MLX/model/whisper/decoding.h new file mode 100644 index 00000000..84f05d76 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/decoding.h @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#pragma once + +#include "whisper.h" +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +class Tokenizer; +class Whisper; + +float compressionRatio(const std::string &Text); + +std::pair>> +detectLanguage(std::shared_ptr Model, const mx::array &Mel, + std::shared_ptr Tokenizer = nullptr); + +struct DecodingOptions { + std::string Task = "transcribe"; + std::optional Language = std::nullopt; + float Temperature = 0.0f; + std::optional SampleLen = std::nullopt; + std::optional BestOf = std::nullopt; + std::optional BeamSize = std::nullopt; + std::optional Patience = std::nullopt; + std::optional LengthPenalty = std::nullopt; + std::optional>> Prompt = + std::nullopt; + std::optional>> Prefix = + std::nullopt; + std::optional>> SuppressTokens = + "-1"; + bool SuppressBlank = true; + bool WithoutTimestamps = false; + std::optional MaxInitialTimestamp = 1.0f; + bool Fp16 = true; +}; + +struct DecodingResult { + mx::array AudioFeatures; + std::string Language; + std::optional> LanguageProbs = std::nullopt; + std::vector Tokens; + std::string Text = ""; + float AvgLogprob = std::numeric_limits::quiet_NaN(); + float NoSpeechProb = std::numeric_limits::quiet_NaN(); + float Temperature = std::numeric_limits::quiet_NaN(); + float CompressionRatio = std::numeric_limits::quiet_NaN(); + DecodingResult() : AudioFeatures(mx::array({})) {} +}; + +class Inference { +public: + explicit Inference(std::shared_ptr Model); + mx::array logits(const mx::array &Tokens, const mx::array &AudioFeatures); + void rearrangeKvCache(const std::vector &SourceIndices); + void reset(); + +private: + std::shared_ptr Model; + std::optional< + std::vector>, + std::optional>>>> + KvCache = std::nullopt; +}; + +class SequenceRanker { +public: + virtual ~SequenceRanker() = default; + virtual std::vector + rank(const std::vector>> &Tokens, + const std::vector> &SumLogprobs) = 0; +}; + +class MaximumLikelihoodRanker : public SequenceRanker { +public: + explicit MaximumLikelihoodRanker(std::optional LengthPenalty); + std::vector + rank(const std::vector>> &Tokens, + const std::vector> &SumLogprobs) override; + +private: + std::optional LengthPenalty; +}; + +class TokenDecoder { +public: + virtual ~TokenDecoder() = default; + virtual void reset() = 0; + virtual std::tuple + update(const mx::array &Tokens, const mx::array &Logits, + const mx::array &SumLogprobs) = 0; + virtual std::pair + finalize(const mx::array &Tokens, const mx::array &SumLogprobs) = 0; +}; + +class GreedyDecoder : public TokenDecoder { +public: + GreedyDecoder(float Temperature, int Eot); + void reset() override; + std::tuple + update(const mx::array &Tokens, const mx::array &Logits, + const mx::array &SumLogprobs) override; + std::pair + finalize(const mx::array &Tokens, const mx::array &SumLogprobs) override; + +private: + float Temperature; + int Eot; +}; + +class LogitFilter { +public: + virtual ~LogitFilter() = default; + virtual mx::array apply(const mx::array &Logits, const mx::array &Tokens) = 0; + std::string Name = "LogitFilter"; +}; + +class SuppressBlank : public LogitFilter { +public: + SuppressBlank(std::shared_ptr Tokenizer, int SampleBegin, + int NVocab); + mx::array apply(const mx::array &Logits, const mx::array &Tokens) override; + +private: + int SampleBegin; + mx::array Mask; +}; + +class SuppressTokens : public LogitFilter { +public: + SuppressTokens(const std::vector &SuppressTokens, int NVocab); + mx::array apply(const mx::array &Logits, const mx::array &Tokens) override; + +private: + mx::array Mask; +}; + +class ApplyTimestampRules : public LogitFilter { +public: + ApplyTimestampRules(std::shared_ptr Tokenizer, int SampleBegin, + std::optional MaxInitialTimestampIndex); + mx::array apply(const mx::array &Logits, const mx::array &Tokens) override; + +private: + std::shared_ptr Tokenizer; + int SampleBegin; + std::optional MaxInitialTimestampIndex; +}; + +class DecodingTask { +public: + DecodingTask(std::shared_ptr Model, const DecodingOptions &Options); + std::vector run(const mx::array &Mel); + +private: + DecodingOptions verifyOptions(const DecodingOptions &InputOptions); + std::vector getInitialTokens(); + std::vector getSuppressTokens(); + mx::array getAudioFeatures(const mx::array &Mel); + std::pair, + std::optional>>> + detectLanguage(const mx::array &AudioFeatures, mx::array &Tokens); + std::tuple + mainLoop(const mx::array &AudioFeatures, const mx::array &Tokens); + std::shared_ptr Model; + std::shared_ptr Tokenizer; + DecodingOptions Options; + int NGroup; + int NCtx; + int SampleLen; + std::vector SotSequence; + std::vector InitialTokens; + int SampleBegin; + int SotIndex; + std::unique_ptr Inference; + std::unique_ptr SequenceRanker; + std::unique_ptr Decoder; + std::vector> LogitFilters; +}; + +std::variant> +decode(std::shared_ptr Model, const mx::array &Mel, + const DecodingOptions &Options = DecodingOptions()); + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp b/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp new file mode 100644 index 00000000..4178261f --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp @@ -0,0 +1,784 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#include "tokenizer.h" +#include "mlx/base.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +const std::vector> LANGUAGES = { + {"en", "english"}, {"zh", "chinese"}, {"de", "german"}, + {"es", "spanish"}, {"ru", "russian"}, {"ko", "korean"}, + {"fr", "french"}, {"ja", "japanese"}, {"pt", "portuguese"}, + {"tr", "turkish"}, {"pl", "polish"}, {"ca", "catalan"}, + {"nl", "dutch"}, {"ar", "arabic"}, {"sv", "swedish"}, + {"it", "italian"}, {"id", "indonesian"}, {"hi", "hindi"}, + {"fi", "finnish"}, {"vi", "vietnamese"}, {"he", "hebrew"}, + {"uk", "ukrainian"}, {"el", "greek"}, {"ms", "malay"}, + {"cs", "czech"}, {"ro", "romanian"}, {"da", "danish"}, + {"hu", "hungarian"}, {"ta", "tamil"}, {"no", "norwegian"}, + {"th", "thai"}, {"ur", "urdu"}, {"hr", "croatian"}, + {"bg", "bulgarian"}, {"lt", "lithuanian"}, {"la", "latin"}, + {"mi", "maori"}, {"ml", "malayalam"}, {"cy", "welsh"}, + {"sk", "slovak"}, {"te", "telugu"}, {"fa", "persian"}, + {"lv", "latvian"}, {"bn", "bengali"}, {"sr", "serbian"}, + {"az", "azerbaijani"}, {"sl", "slovenian"}, {"kn", "kannada"}, + {"et", "estonian"}, {"mk", "macedonian"}, {"br", "breton"}, + {"eu", "basque"}, {"is", "icelandic"}, {"hy", "armenian"}, + {"ne", "nepali"}, {"mn", "mongolian"}, {"bs", "bosnian"}, + {"kk", "kazakh"}, {"sq", "albanian"}, {"sw", "swahili"}, + {"gl", "galician"}, {"mr", "marathi"}, {"pa", "punjabi"}, + {"si", "sinhala"}, {"km", "khmer"}, {"sn", "shona"}, + {"yo", "yoruba"}, {"so", "somali"}, {"af", "afrikaans"}, + {"oc", "occitan"}, {"ka", "georgian"}, {"be", "belarusian"}, + {"tg", "tajik"}, {"sd", "sindhi"}, {"gu", "gujarati"}, + {"am", "amharic"}, {"yi", "yiddish"}, {"lo", "lao"}, + {"uz", "uzbek"}, {"fo", "faroese"}, {"ht", "haitian creole"}, + {"ps", "pashto"}, {"tk", "turkmen"}, {"nn", "nynorsk"}, + {"mt", "maltese"}, {"sa", "sanskrit"}, {"lb", "luxembourgish"}, + {"my", "myanmar"}, {"bo", "tibetan"}, {"tl", "tagalog"}, + {"mg", "malagasy"}, {"as", "assamese"}, {"tt", "tatar"}, + {"haw", "hawaiian"}, {"ln", "lingala"}, {"ha", "hausa"}, + {"ba", "bashkir"}, {"jw", "javanese"}, {"su", "sundanese"}, + {"yue", "cantonese"}}; + +// Helper function to find language by code +std::string findLanguageByCode(const std::string &Code) { + for (const auto &[LangCode, Language] : LANGUAGES) { + if (LangCode == Code) { + return Language; + } + } + return ""; +} + +// Helper function to check if language code exists +bool languageCodeExists(const std::string &Code) { + for (const auto &[LangCode, Language] : LANGUAGES) { + if (LangCode == Code) { + return true; + } + } + return false; +} + +const std::unordered_map ToLanguageCode = []() { + std::unordered_map ToLanguageCodeMap; + + // Add language to code mappings + for (const auto &[Code, Language] : LANGUAGES) { + ToLanguageCodeMap[Language] = Code; + } + + // Add aliases + ToLanguageCodeMap["burmese"] = "my"; + ToLanguageCodeMap["valencian"] = "ca"; + ToLanguageCodeMap["flemish"] = "nl"; + ToLanguageCodeMap["haitian"] = "ht"; + ToLanguageCodeMap["letzeburgesch"] = "lb"; + ToLanguageCodeMap["pushto"] = "ps"; + ToLanguageCodeMap["panjabi"] = "pa"; + ToLanguageCodeMap["moldavian"] = "ro"; + ToLanguageCodeMap["moldovan"] = "ro"; + ToLanguageCodeMap["sinhalese"] = "si"; + ToLanguageCodeMap["castilian"] = "es"; + ToLanguageCodeMap["mandarin"] = "zh"; + + return ToLanguageCodeMap; +}(); + +namespace { +// Modern base64 decode implementation +// Based on RFC 4648 standard with better error handling +std::string base64Decode(const std::string &Encoded) { + // Standard base64 alphabet lookup table + static const std::array DecodeTable = []() { + std::array Table{}; + // Initialize all values to -1 (invalid) + std::fill(Table.begin(), Table.end(), -1); + + // Set valid characters + for (int I = 0; I < 26; ++I) { + Table['A' + I] = I; // A-Z: 0-25 + Table['a' + I] = I + 26; // a-z: 26-51 + } + for (int I = 0; I < 10; ++I) { + Table['0' + I] = I + 52; // 0-9: 52-61 + } + Table['+'] = 62; + Table['/'] = 63; + // Padding character + Table['='] = -2; + + return Table; + }(); + + if (Encoded.empty() || Encoded == "=") { + return {}; + } + + // Remove whitespace and validate length + std::string Cleaned; + Cleaned.reserve(Encoded.size()); + for (char C : Encoded) { + if (C != ' ' && C != '\t' && C != '\r' && C != '\n') { + Cleaned.push_back(C); + } + } + + if (Cleaned.size() % 4 != 0) { + throw std::invalid_argument("Invalid base64 string length"); + } + + std::string Result; + Result.reserve((Cleaned.size() * 3) / 4); + + for (size_t I = 0; I < Cleaned.size(); I += 4) { + std::array Values{}; + int PaddingCount = 0; + + // Decode 4 characters at a time + for (int J = 0; J < 4; ++J) { + unsigned char C = static_cast(Cleaned[I + J]); + int Val = DecodeTable[C]; + + if (Val == -1) { + throw std::invalid_argument("Invalid base64 character: " + + std::to_string(C)); + } + if (Val == -2) { // Padding + Values[J] = 0; + ++PaddingCount; + } else { + if (PaddingCount > 0) { + throw std::invalid_argument("Invalid padding in base64 string"); + } + Values[J] = Val; + } + } + + // Convert 4 base64 chars to 3 bytes + int Combined = + (Values[0] << 18) | (Values[1] << 12) | (Values[2] << 6) | Values[3]; + + // Extract bytes based on padding + if (PaddingCount == 0) { + Result.push_back(static_cast((Combined >> 16) & 0xFF)); + Result.push_back(static_cast((Combined >> 8) & 0xFF)); + Result.push_back(static_cast(Combined & 0xFF)); + } else if (PaddingCount == 1) { + Result.push_back(static_cast((Combined >> 16) & 0xFF)); + Result.push_back(static_cast((Combined >> 8) & 0xFF)); + } else if (PaddingCount == 2) { + Result.push_back(static_cast((Combined >> 16) & 0xFF)); + } else { + throw std::invalid_argument("Invalid padding count in base64 string"); + } + } + + return Result; +} +} // namespace + +// Encoding implementation +Encoding::Encoding(const std::string &Name, int ExplicitNVocab, + const std::string &PatStr, + const std::unordered_map &MergeableRanks, + const std::unordered_map &SpecialTokens) + : Name(Name), PatStr(PatStr), MergeableRanks(MergeableRanks), + SpecialTokens(SpecialTokens) { + + EotToken = 50257; // Default EOT token + + // Build reverse mapping + for (const auto &[Token, Id] : MergeableRanks) { + TokenToString[Id] = Token; + } + for (const auto &[Token, Id] : SpecialTokens) { + TokenToString[Id] = Token; + SpecialTokensSet.insert(Token); + } +} + +std::vector Encoding::encode(const std::string &Text) const { + std::vector Result; + + if (Text.empty()) { + return Result; + } + + // Handle the simplified cases for symbols that are commonly in vocab + // First try to encode the text as a single token + auto DirectIt = MergeableRanks.find(Text); + if (DirectIt != MergeableRanks.end()) { + Result.push_back(DirectIt->second); + return Result; + } + + // For complex text, we need a better approach than just splitting by spaces + // This is a more sophisticated approach that handles individual characters + // and symbols + + size_t I = 0; + while (I < Text.length()) { + // Try to find the longest matching token starting at position I + std::string LongestMatch; + int LongestMatchToken = -1; + + // Look for longest match from current position + for (size_t Len = std::min(Text.length() - I, size_t(10)); Len > 0; --Len) { + std::string Candidate = Text.substr(I, Len); + auto It = MergeableRanks.find(Candidate); + if (It != MergeableRanks.end()) { + LongestMatch = Candidate; + LongestMatchToken = It->second; + break; // Take the first (longest) match + } + } + + if (LongestMatchToken != -1) { + Result.push_back(LongestMatchToken); + I += LongestMatch.length(); + } else { + // If no match found, try single character + std::string SingleChar = Text.substr(I, 1); + auto CharIt = MergeableRanks.find(SingleChar); + if (CharIt != MergeableRanks.end()) { + Result.push_back(CharIt->second); + } else { + // Handle UTF-8 sequences - try to find multi-byte character + size_t CharLen = 1; + unsigned char FirstByte = static_cast(Text[I]); + if ((FirstByte & 0x80) != 0) { + // UTF-8 multi-byte character + if ((FirstByte & 0xE0) == 0xC0) + CharLen = 2; + else if ((FirstByte & 0xF0) == 0xE0) + CharLen = 3; + else if ((FirstByte & 0xF8) == 0xF0) + CharLen = 4; + } + + if (I + CharLen <= Text.length()) { + std::string MultiByteChar = Text.substr(I, CharLen); + auto MultiIt = MergeableRanks.find(MultiByteChar); + if (MultiIt != MergeableRanks.end()) { + Result.push_back(MultiIt->second); + I += CharLen; + continue; + } + } + + // If we still can't find it, encode as bytes + for (size_t J = 0; J < CharLen && I + J < Text.length(); ++J) { + unsigned char Byte = static_cast(Text[I + J]); + // Try to find byte encoding - tiktoken often has byte-level tokens + std::string ByteStr(1, static_cast(Byte)); + auto ByteIt = MergeableRanks.find(ByteStr); + if (ByteIt != MergeableRanks.end()) { + Result.push_back(ByteIt->second); + } + } + I += CharLen; + } + } + } + + return Result; +} + +std::string Encoding::decode(const std::vector &TokenIds) const { + std::string Result; + for (int Id : TokenIds) { + auto It = TokenToString.find(Id); + if (It != TokenToString.end()) { + Result += It->second; + } + } + return Result; +} + +int Encoding::encodeSingleToken(const std::string &Token) const { + auto It = SpecialTokens.find(Token); + if (It != SpecialTokens.end()) { + return It->second; + } + + auto MergeIt = MergeableRanks.find(Token); + if (MergeIt != MergeableRanks.end()) { + return MergeIt->second; + } + + throw std::runtime_error("Token not found: " + Token); +} +// Tokenizer implementation +Tokenizer::Tokenizer(std::unique_ptr EncodingPtr, int NumLanguages, + const std::optional &Language, + const std::optional &Task) + : EncodingPtr(std::move(EncodingPtr)), NumLanguages(NumLanguages), + Language(Language), Task(Task) { + + for (const auto &Special : this->EncodingPtr->SpecialTokensSet) { + int SpecialToken = this->EncodingPtr->encodeSingleToken(Special); + SpecialTokens[Special] = SpecialToken; + } + + // Build SOT sequence + int Sot = SpecialTokens["<|startoftranscript|>"]; + int Translate = SpecialTokens["<|translate|>"]; + int Transcribe = SpecialTokens["<|transcribe|>"]; + std::vector Langs; + for (const auto &[Code, Name] : LANGUAGES) { + Langs.push_back(Code); + if (static_cast(Langs.size()) >= NumLanguages) + break; + } + + SotSequence = {Sot}; + if (Language.has_value()) { + auto LangIt = std::find(Langs.begin(), Langs.end(), Language.value()); + if (LangIt != Langs.end()) { + int LangIndex = std::distance(Langs.begin(), LangIt); + SotSequence.push_back(Sot + 1 + LangIndex); + } + } + if (Task.has_value()) { + int TaskToken = (Task.value() == "transcribe") ? Transcribe : Translate; + SotSequence.push_back(TaskToken); + } +} + +std::vector Tokenizer::encode(const std::string &Text) const { + return EncodingPtr->encode(Text); +} + +std::string Tokenizer::decode(const std::vector &TokenIds) const { + std::vector FilteredTokens; + int TimestampBegin = getTimestampBegin(); + + for (int Token : TokenIds) { + if (Token < TimestampBegin) { + FilteredTokens.push_back(Token); + } + } + + return EncodingPtr->decode(FilteredTokens); +} + +std::string +Tokenizer::decodeWithTimestamps(const std::vector &TokenIds) const { + return EncodingPtr->decode(TokenIds); +} + +// Cached property implementations +int Tokenizer::getEot() const { + if (!CachedEot.has_value()) { + CachedEot = EncodingPtr->EotToken; + } + return CachedEot.value(); +} + +int Tokenizer::getTranscribe() const { + if (!CachedTranscribe.has_value()) { + CachedTranscribe = SpecialTokens.at("<|transcribe|>"); + } + return CachedTranscribe.value(); +} + +int Tokenizer::getTranslate() const { + if (!CachedTranslate.has_value()) { + CachedTranslate = SpecialTokens.at("<|translate|>"); + } + return CachedTranslate.value(); +} + +int Tokenizer::getSot() const { + if (!CachedSot.has_value()) { + CachedSot = SpecialTokens.at("<|startoftranscript|>"); + } + return CachedSot.value(); +} + +int Tokenizer::getSotLm() const { + if (!CachedSotLm.has_value()) { + CachedSotLm = SpecialTokens.at("<|startoflm|>"); + } + return CachedSotLm.value(); +} + +int Tokenizer::getSotPrev() const { + if (!CachedSotPrev.has_value()) { + CachedSotPrev = SpecialTokens.at("<|startofprev|>"); + } + return CachedSotPrev.value(); +} + +int Tokenizer::getNoSpeech() const { + if (!CachedNoSpeech.has_value()) { + CachedNoSpeech = SpecialTokens.at("<|nospeech|>"); + } + return CachedNoSpeech.value(); +} + +int Tokenizer::getNoTimestamps() const { + if (!CachedNoTimestamps.has_value()) { + CachedNoTimestamps = SpecialTokens.at("<|notimestamps|>"); + } + return CachedNoTimestamps.value(); +} + +int Tokenizer::getTimestampBegin() const { + if (!CachedTimestampBegin.has_value()) { + CachedTimestampBegin = SpecialTokens.at("<|0.00|>"); + } + return CachedTimestampBegin.value(); +} + +int Tokenizer::languageToken() const { + if (!Language.has_value()) { + throw std::runtime_error( + "This tokenizer does not have language token configured"); + } + return toLanguageToken(Language.value()); +} + +int Tokenizer::toLanguageToken(const std::string &LanguageCode) const { + std::string TokenName = "<|" + LanguageCode + "|>"; + auto It = SpecialTokens.find(TokenName); + if (It != SpecialTokens.end()) { + return It->second; + } + throw std::runtime_error("Language " + LanguageCode + + " not found in tokenizer."); +} + +std::vector Tokenizer::getAllLanguageTokens() const { + if (!CachedAllLanguageTokens.has_value()) { + std::vector Result; + + std::vector SortedLanguageTokens; + for (const auto &[Token, TokenId] : SpecialTokens) { + std::string Stripped = Token; + const std::string CharsToStrip = "<|>"; + while (!Stripped.empty() && + CharsToStrip.find(Stripped.front()) != std::string::npos) { + Stripped.erase(0, 1); + } + while (!Stripped.empty() && + CharsToStrip.find(Stripped.back()) != std::string::npos) { + Stripped.pop_back(); + } + if (languageCodeExists(Stripped)) { + SortedLanguageTokens.push_back(Token); + } + } + + std::sort(SortedLanguageTokens.begin(), SortedLanguageTokens.end()); + + for (const std::string &Token : SortedLanguageTokens) { + auto It = SpecialTokens.find(Token); + if (It != SpecialTokens.end()) { + Result.push_back(It->second); + } + } + if (static_cast(Result.size()) > NumLanguages) { + Result.resize(NumLanguages); + } + + CachedAllLanguageTokens = Result; + } + return CachedAllLanguageTokens.value(); +} + +std::vector Tokenizer::getAllLanguageCodes() const { + if (!CachedAllLanguageCodes.has_value()) { + std::vector Result; + auto LanguageTokens = getAllLanguageTokens(); + + for (int Token : LanguageTokens) { + std::string Decoded = EncodingPtr->decode({Token}); + // Remove <| and |> from decoded token + if (Decoded.size() > 4 && Decoded.substr(0, 2) == "<|" && + Decoded.substr(Decoded.size() - 2) == "|>") { + Result.push_back(Decoded.substr(2, Decoded.size() - 4)); + } + } + + CachedAllLanguageCodes = Result; + } + return CachedAllLanguageCodes.value(); +} + +std::vector Tokenizer::getSotSequenceIncludingNotimestamps() const { + if (!CachedSotSequenceIncludingNotimestamps.has_value()) { + std::vector Result = SotSequence; + Result.push_back(getNoTimestamps()); + CachedSotSequenceIncludingNotimestamps = Result; + } + return CachedSotSequenceIncludingNotimestamps.value(); +} + +std::vector Tokenizer::getNonSpeechTokens() const { + if (!CachedNonSpeechTokens.has_value()) { + std::unordered_set ResultSet; + + std::vector Symbols = {"\"", "#", "(", ")", "*", "+", "/", + ":", ";", "<", "=", ">", "@", "[", + "\\", "]", "^", "_", "`", "{", "|", + "}", "~", "「", "」", "『", "』"}; + + std::vector AdditionalSymbols = { + "<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", + "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "♪♪♪"}; + Symbols.insert(Symbols.end(), AdditionalSymbols.begin(), + AdditionalSymbols.end()); + + std::vector Miscellaneous = {"♩", "♪", "♫", "♬", + "♭", "♮", "♯"}; + + auto DashTokens = EncodingPtr->encode(" -"); + if (!DashTokens.empty()) { + ResultSet.insert(DashTokens[0]); + } + + auto QuoteTokens = EncodingPtr->encode(" '"); + if (!QuoteTokens.empty()) { + ResultSet.insert(QuoteTokens[0]); + } + + // Process all symbols + std::vector AllSymbols = Symbols; + AllSymbols.insert(AllSymbols.end(), Miscellaneous.begin(), + Miscellaneous.end()); + + for (const std::string &Symbol : AllSymbols) { + try { + // Try encoding the symbol directly + auto DirectTokens = EncodingPtr->encode(Symbol); + bool IsMiscellaneous = + std::find(Miscellaneous.begin(), Miscellaneous.end(), Symbol) != + Miscellaneous.end(); + + if (DirectTokens.size() == 1 || IsMiscellaneous) { + if (!DirectTokens.empty()) { + ResultSet.insert(DirectTokens[0]); + } + } + + // Try encoding the symbol with a space prefix + auto SpacedTokens = EncodingPtr->encode(" " + Symbol); + if (SpacedTokens.size() == 1 || IsMiscellaneous) { + if (!SpacedTokens.empty()) { + ResultSet.insert(SpacedTokens[0]); + } + } + } catch (...) { + // Ignore encoding errors for individual symbols + } + } + + // Convert set to sorted vector + std::vector Result(ResultSet.begin(), ResultSet.end()); + std::sort(Result.begin(), Result.end()); + + CachedNonSpeechTokens = Result; + } + return CachedNonSpeechTokens.value(); +} + +std::pair, std::vector>> +Tokenizer::splitToWordTokens(const std::vector &Tokens) const { + std::unordered_set NoSpaceLanguages = {"zh", "ja", "th", + "lo", "my", "yue"}; + + if (Language.has_value() && NoSpaceLanguages.count(Language.value())) { + return splitTokensOnUnicode(Tokens); + } + + return splitTokensOnSpaces(Tokens); +} + +std::pair, std::vector>> +Tokenizer::splitTokensOnUnicode(const std::vector &Tokens) const { + std::string DecodedFull = decodeWithTimestamps(Tokens); + const std::string ReplacementChar = "\uFFFD"; + + std::vector Words; + std::vector> WordTokens; + std::vector CurrentTokens; + size_t UnicodeOffset = 0; + + for (int Token : Tokens) { + CurrentTokens.push_back(Token); + std::string Decoded = decodeWithTimestamps(CurrentTokens); + + bool ValidUnicode = + (Decoded.find(ReplacementChar) == std::string::npos) || + (UnicodeOffset + Decoded.find(ReplacementChar) < DecodedFull.size() && + DecodedFull.substr(UnicodeOffset + Decoded.find(ReplacementChar), + ReplacementChar.length()) == ReplacementChar); + + if (ValidUnicode) { + Words.push_back(Decoded); + WordTokens.push_back(CurrentTokens); + CurrentTokens.clear(); + UnicodeOffset += Decoded.size(); + } + } + + return {Words, WordTokens}; +} + +std::pair, std::vector>> +Tokenizer::splitTokensOnSpaces(const std::vector &Tokens) const { + auto [Subwords, SubwordTokensList] = splitTokensOnUnicode(Tokens); + + std::vector Words; + std::vector> WordTokens; + + for (size_t I = 0; I < Subwords.size(); ++I) { + const std::string &Subword = Subwords[I]; + const std::vector &SubwordTokens = SubwordTokensList[I]; + + bool Special = !SubwordTokens.empty() && SubwordTokens[0] >= getEot(); + bool WithSpace = !Subword.empty() && Subword[0] == ' '; + + // Check if it's punctuation + std::string Trimmed = Subword; + Trimmed.erase(0, Trimmed.find_first_not_of(" \t\n\r")); + Trimmed.erase(Trimmed.find_last_not_of(" \t\n\r") + 1); + bool Punctuation = Trimmed.size() == 1 && std::ispunct(Trimmed[0]); + + if (Special || WithSpace || Punctuation || Words.empty()) { + Words.push_back(Subword); + WordTokens.push_back(SubwordTokens); + } else { + Words.back() += Subword; + WordTokens.back().insert(WordTokens.back().end(), SubwordTokens.begin(), + SubwordTokens.end()); + } + } + + return {Words, WordTokens}; +} + +// Factory functions +std::unique_ptr getEncoding(const std::string &Name, + int NumLanguages) { + std::string VocabPath = "assets/" + Name + ".tiktoken"; + + std::unordered_map Ranks; + std::ifstream File(VocabPath); + + if (!File.is_open()) { + throw std::runtime_error("Failed to open vocab file: " + VocabPath); + } + + std::string Line; + while (std::getline(File, Line)) { + if (Line.empty()) + continue; + + std::istringstream Iss(Line); + std::string Token, RankStr; + if (Iss >> Token >> RankStr) { + std::string DecodedToken = base64Decode(Token); + Ranks[DecodedToken] = std::stoi(RankStr); + } + } + File.close(); + + int NVocab = Ranks.size(); + std::unordered_map SpecialTokens; + + // Build special tokens + std::vector Specials = {"<|endoftext|>", + "<|startoftranscript|>"}; + + for (const auto &[Code, LanguageName] : LANGUAGES) { + Specials.push_back("<|" + Code + "|>"); + if (static_cast(Specials.size()) >= NumLanguages + 2) + break; + } + + Specials.insert(Specials.end(), + {"<|translate|>", "<|transcribe|>", "<|startoflm|>", + "<|startofprev|>", "<|nospeech|>", "<|notimestamps|>"}); + + // Add timestamp tokens + for (int I = 0; I < 1501; ++I) { + std::ostringstream Oss; + Oss << "<|" << std::fixed << std::setprecision(2) << (I * 0.02) << "|>"; + Specials.push_back(Oss.str()); + } + + for (const std::string &Token : Specials) { + SpecialTokens[Token] = NVocab++; + } + + std::string PatStr = + R"('s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+)"; + + return std::make_unique(VocabPath, NVocab, PatStr, Ranks, + SpecialTokens); +} + +std::unique_ptr +getTokenizer(bool Multilingual, int NumLanguages, + const std::optional &Language, + const std::optional &Task) { + std::optional ProcessedLanguage = Language; + + if (ProcessedLanguage.has_value()) { + std::string Lang = ProcessedLanguage.value(); + std::transform(Lang.begin(), Lang.end(), Lang.begin(), ::tolower); + + if (!languageCodeExists(Lang)) { + auto It = ToLanguageCode.find(Lang); + if (It != ToLanguageCode.end()) { + ProcessedLanguage = It->second; + } else { + throw std::runtime_error("Unsupported language: " + Lang); + } + } else { + ProcessedLanguage = Lang; + } + } + + std::string EncodingName; + std::optional FinalLanguage; + std::optional FinalTask; + + if (Multilingual) { + EncodingName = "multilingual"; + FinalLanguage = ProcessedLanguage.value_or("en"); + FinalTask = Task.value_or("transcribe"); + } else { + EncodingName = "gpt2"; + FinalLanguage = std::nullopt; + FinalTask = std::nullopt; + } + + auto Encoding = getEncoding(EncodingName, NumLanguages); + + return std::make_unique(std::move(Encoding), NumLanguages, + FinalLanguage, FinalTask); +} + +std::unique_ptr +createWhisperTokenizer(const std::optional &Language, + const std::optional &Task) { + return getTokenizer(true, 99, Language, Task); +} + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/tokenizer.h b/plugins/wasi_nn/MLX/model/whisper/tokenizer.h new file mode 100644 index 00000000..23ecf401 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/tokenizer.h @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +extern const std::vector> LANGUAGES; +extern const std::unordered_map ToLanguageCode; + +std::string findLanguageByCode(const std::string &Code); +bool languageCodeExists(const std::string &Code); + +class Encoding { +public: + Encoding(const std::string &Name, int ExplicitNVocab, + const std::string &PatStr, + const std::unordered_map &MergeableRanks, + const std::unordered_map &SpecialTokens); + + std::vector encode(const std::string &Text) const; + std::string decode(const std::vector &TokenIds) const; + int encodeSingleToken(const std::string &Token) const; + + int EotToken; + std::unordered_set SpecialTokensSet; + +private: + std::string Name; + std::string PatStr; + std::unordered_map MergeableRanks; + std::unordered_map SpecialTokens; + std::unordered_map TokenToString; +}; + +class Tokenizer { +public: + Tokenizer(std::unique_ptr EncodingPtr, int NumLanguages, + const std::optional &Language = std::nullopt, + const std::optional &Task = std::nullopt); + + std::vector encode(const std::string &Text) const; + std::string decode(const std::vector &TokenIds) const; + std::string decodeWithTimestamps(const std::vector &TokenIds) const; + + // Cached properties + int getEot() const; + int getTranscribe() const; + int getTranslate() const; + int getSot() const; + int getSotLm() const; + int getSotPrev() const; + int getNoSpeech() const; + int getNoTimestamps() const; + int getTimestampBegin() const; + int languageToken() const; + int toLanguageToken(const std::string &LanguageCode) const; + std::vector getAllLanguageTokens() const; + std::vector getAllLanguageCodes() const; + std::vector getSotSequenceIncludingNotimestamps() const; + std::vector getNonSpeechTokens() const; + + std::pair, std::vector>> + splitToWordTokens(const std::vector &Tokens) const; + + // Public members + std::unique_ptr EncodingPtr; + int NumLanguages; + std::optional Language; + std::optional Task; + std::vector SotSequence; + std::unordered_map SpecialTokens; + +private: + std::pair, std::vector>> + splitTokensOnUnicode(const std::vector &Tokens) const; + + std::pair, std::vector>> + splitTokensOnSpaces(const std::vector &Tokens) const; + + // Cached values + mutable std::optional CachedEot; + mutable std::optional CachedTranscribe; + mutable std::optional CachedTranslate; + mutable std::optional CachedSot; + mutable std::optional CachedSotLm; + mutable std::optional CachedSotPrev; + mutable std::optional CachedNoSpeech; + mutable std::optional CachedNoTimestamps; + mutable std::optional CachedTimestampBegin; + mutable std::optional> CachedAllLanguageTokens; + mutable std::optional> CachedAllLanguageCodes; + mutable std::optional> + CachedSotSequenceIncludingNotimestamps; + mutable std::optional> CachedNonSpeechTokens; +}; + +std::unique_ptr getEncoding(const std::string &Name = "gpt2", + int NumLanguages = 99); + +std::unique_ptr +getTokenizer(bool Multilingual, int NumLanguages = 99, + const std::optional &Language = std::nullopt, + const std::optional &Task = std::nullopt); + +// Helper function to create a tokenizer for whisper transcription +std::unique_ptr createWhisperTokenizer( + const std::optional &Language = std::nullopt, + const std::optional &Task = std::nullopt); + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/whisper.cpp b/plugins/wasi_nn/MLX/model/whisper/whisper.cpp new file mode 100644 index 00000000..ccc79190 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/whisper.cpp @@ -0,0 +1,461 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#include "whisper.h" +#include "mlx/activations.h" +#include "mlx/base.h" +#include "mlx/convolution.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "mlx/normalization.h" +#include "mlx/transformer.h" +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +namespace whisper { + +mx::array sinusoids(int Length, int Channels, float MaxTimescale) { + assert(Channels % 2 == 0); + float LogTimescaleIncrement = std::log(MaxTimescale) / (Channels / 2 - 1); + mx::array InvTimescales = + mx::exp(-LogTimescaleIncrement * mx::arange(Channels / 2)); + mx::array LengthArray = mx::arange(Length); + mx::array LengthReshaped = reshape(LengthArray, {Length, 1}); + mx::array InvTimescalesReshaped = reshape(InvTimescales, {1, -1}); + mx::array ScaledTime = LengthReshaped * InvTimescalesReshaped; + return mx::concatenate({mx::sin(ScaledTime), mx::cos(ScaledTime)}, + /*axis=*/1); +} + +ModelDimensions ModelDimensions::fromDict(const simdjson::dom::object &Obj) { + ModelDimensions Dims; + if (auto Val = Obj["n_mels"]; !Val.error()) { + Dims.NMels = Val.get(); + } + if (auto Val = Obj["n_audio_ctx"]; !Val.error()) { + Dims.NAudioCtx = Val.get(); + } + if (auto Val = Obj["n_audio_state"]; !Val.error()) { + Dims.NAudioState = Val.get(); + } + if (auto Val = Obj["n_audio_head"]; !Val.error()) { + Dims.NAudioHead = Val.get(); + } + if (auto Val = Obj["n_audio_layer"]; !Val.error()) { + Dims.NAudioLayer = Val.get(); + } + if (auto Val = Obj["n_vocab"]; !Val.error()) { + Dims.NVocab = Val.get(); + } + if (auto Val = Obj["n_text_ctx"]; !Val.error()) { + Dims.NTextCtx = Val.get(); + } + if (auto Val = Obj["n_text_state"]; !Val.error()) { + Dims.NTextState = Val.get(); + } + if (auto Val = Obj["n_text_head"]; !Val.error()) { + Dims.NTextHead = Val.get(); + } + if (auto Val = Obj["n_text_layer"]; !Val.error()) { + Dims.NTextLayer = Val.get(); + } + return Dims; +} + +// MultiHeadAttention implementation +MultiHeadAttention::MultiHeadAttention(int NState, int NHead) : NHead(NHead) { + auto Query = std::make_shared(NState, NState); + auto Key = std::make_shared(NState, NState, /*bias=*/false); + auto Value = std::make_shared(NState, NState); + auto Out = std::make_shared(NState, NState); + + registerModule("query", Query); + registerModule("key", Key); + registerModule("value", Value); + registerModule("out", Out); +} + +std::tuple, mx::array> +MultiHeadAttention::forward( + const mx::array &X, const std::optional &Xa, + const std::optional &Mask, + const std::optional> &KvCache) { + + auto Query = std::dynamic_pointer_cast(Submodules["query"]); + auto Key = std::dynamic_pointer_cast(Submodules["key"]); + auto Value = std::dynamic_pointer_cast(Submodules["value"]); + auto Out = std::dynamic_pointer_cast(Submodules["out"]); + + mx::array Q = Query->forward(X); + mx::array K = mx::array({}), V = mx::array({}); + + if (!Xa.has_value()) { + // Self-attention + K = Key->forward(X); + V = Value->forward(X); + if (KvCache.has_value()) { + K = mx::concatenate({KvCache->first, K}, 1); + V = mx::concatenate({KvCache->second, V}, 1); + } + } else if (!KvCache.has_value()) { + // Cross-attention without cache + K = Key->forward(*Xa); + V = Value->forward(*Xa); + } else { + // Cross-attention with cache + K = KvCache->first; + V = KvCache->second; + } + auto [Wv, Qk] = qkvAttention(Q, K, V, Mask); + return {Out->forward(Wv), {K, V}, Qk}; +} + +std::pair +MultiHeadAttention::qkvAttention(const mx::array &Q, const mx::array &K, + const mx::array &V, + const std::optional &Mask) { + + auto Shape = Q.shape(); + int NBatch = Shape[0]; + int NCtx = Shape[1]; + int NState = Shape[2]; + + double Scale = std::pow(NState / NHead, -0.25); + Scale = std::round(Scale * 1000000) / 1000000; + + // Reshape and transpose for multi-head attention + mx::array QReshaped = reshape(Q, {Q.shape(0), Q.shape(1), NHead, -1}); + QReshaped = transpose(QReshaped, {0, 2, 1, 3}) * Scale; + + mx::array KReshaped = reshape(K, {K.shape(0), K.shape(1), NHead, -1}); + KReshaped = transpose(KReshaped, {0, 2, 3, 1}) * Scale; + + mx::array VReshaped = reshape(V, {V.shape(0), V.shape(1), NHead, -1}); + VReshaped = transpose(VReshaped, {0, 2, 1, 3}); + mx::array Qk = mx::matmul(QReshaped, KReshaped); + + if (Mask.has_value()) { + Qk = Qk + slice(*Mask, {0, 0}, {NCtx, NCtx}); + } + mx::array W = mx::softmax(Qk, /*axis=*/-1, /*precise=*/true); + mx::array Out = transpose(mx::matmul(W, VReshaped), {0, 2, 1, 3}); + Out = reshape(Out, {NBatch, NCtx, NState}); + return {Out, Qk}; +} + +// ResidualAttentionBlock implementation +ResidualAttentionBlock::ResidualAttentionBlock(int NState, int NHead, + bool CrossAttention) + : HasCrossAttention(CrossAttention) { + + auto Attn = std::make_shared(NState, NHead); + auto AttnLn = std::make_shared(NState); + + registerModule("attn", Attn); + registerModule("attn_ln", AttnLn); + + if (CrossAttention) { + auto CrossAttn = std::make_shared(NState, NHead); + auto CrossAttnLn = std::make_shared(NState); + registerModule("cross_attn", CrossAttn); + registerModule("cross_attn_ln", CrossAttnLn); + } + + int NMlp = NState * 4; + auto Mlp1 = std::make_shared(NState, NMlp); + auto Mlp2 = std::make_shared(NMlp, NState); + auto MlpLn = std::make_shared(NState); + + registerModule("mlp1", Mlp1); + registerModule("mlp2", Mlp2); + registerModule("mlp_ln", MlpLn); +} + +std::tuple>, + std::optional>>, + std::optional> +ResidualAttentionBlock::forward( + const mx::array &X, const std::optional &Xa, + const std::optional &Mask, + const std::optional< + std::pair>, + std::optional>>> &KvCache) { + + auto Attn = std::dynamic_pointer_cast(Submodules["attn"]); + auto AttnLn = std::dynamic_pointer_cast(Submodules["attn_ln"]); + auto Mlp1 = std::dynamic_pointer_cast(Submodules["mlp1"]); + auto Mlp2 = std::dynamic_pointer_cast(Submodules["mlp2"]); + auto MlpLn = std::dynamic_pointer_cast(Submodules["mlp_ln"]); + + std::optional> Kv = + KvCache ? KvCache->first : std::nullopt; + std::optional> CrossKv = + KvCache ? KvCache->second : std::nullopt; + auto [Y, NewKv, _] = + Attn->forward(AttnLn->forward(X), std::nullopt, Mask, Kv); + mx::array Result = X + Y; + + std::optional CrossQk = std::nullopt; + std::optional> NewCrossKv = std::nullopt; + + if (HasCrossAttention) { + auto CrossAttn = + std::dynamic_pointer_cast(Submodules["cross_attn"]); + auto CrossAttnLn = + std::dynamic_pointer_cast(Submodules["cross_attn_ln"]); + + auto [CrossY, TempCrossKv, TempCrossQk] = CrossAttn->forward( + CrossAttnLn->forward(Result), Xa, std::nullopt, CrossKv); + Result = Result + CrossY; + NewCrossKv = TempCrossKv; + CrossQk = TempCrossQk; + } + Result = Result + Mlp2->forward( + mlx::core::gelu(Mlp1->forward(MlpLn->forward(Result)))); + return {Result, {NewKv, NewCrossKv}, CrossQk}; +} + +// AudioEncoder implementation +AudioEncoder::AudioEncoder(int NMels, int NCtx, int NState, int NHead, + int NLayer, mx::Dtype Dtype) { + auto Conv1 = std::make_shared(NMels, NState, 3, 1, 1); + auto Conv2 = std::make_shared(NState, NState, 3, 2, 1); + + registerModule("conv1", Conv1); + registerModule("conv2", Conv2); + + PositionalEmbedding = astype(sinusoids(NCtx, NState), Dtype); + + for (int I = 0; I < NLayer; ++I) { + auto Block = std::make_shared(NState, NHead); + Blocks.push_back(Block); + } + registerLayer("blocks", Blocks); + + auto LnPost = std::make_shared(NState); + registerModule("ln_post", LnPost); +} + +mx::array AudioEncoder::forward(const mx::array &X) { + auto Conv1 = std::dynamic_pointer_cast(Submodules["conv1"]); + auto Conv2 = std::dynamic_pointer_cast(Submodules["conv2"]); + auto LnPost = std::dynamic_pointer_cast(Submodules["ln_post"]); + mx::array Result = Conv1->forward(X); + Result = mlx::core::gelu(Result); + Result = Conv2->forward(Result); + Result = mlx::core::gelu(Result); + assert(Result.shape()[1] == PositionalEmbedding.shape()[0] && + Result.shape()[2] == PositionalEmbedding.shape()[1]); + + Result = Result + PositionalEmbedding; + for (auto &Block : Blocks) { + auto [NewResult, _, __] = Block->forward(Result); + Result = NewResult; + } + return LnPost->forward(Result); +} + +// TextDecoder implementation +TextDecoder::TextDecoder(int NVocab, int NCtx, int NState, int NHead, + int NLayer, mx::Dtype Dtype) { + auto TokenEmbedding = std::make_shared(NVocab, NState); + registerModule("token_embedding", TokenEmbedding); + + registerParameter("positional_embedding", mx::zeros({NCtx, NState})); + + for (int I = 0; I < NLayer; ++I) { + auto Block = std::make_shared( + NState, NHead, /*cross_attention=*/true); + Blocks.push_back(Block); + } + registerLayer("blocks", Blocks); + + auto Ln = std::make_shared(NState); + registerModule("ln", Ln); + + Mask = astype(nn::MultiHeadAttention::createAdditiveCausalMask(NCtx), Dtype); +} + +std::tuple< + mx::array, + std::vector>, + std::optional>>>, + std::vector>> +TextDecoder::forward( + const mx::array &X, const mx::array &Xa, + const std::optional< + std::vector>, + std::optional>>>> + &KvCache) { + + auto TokenEmbedding = + std::dynamic_pointer_cast(Submodules["token_embedding"]); + auto Ln = std::dynamic_pointer_cast(Submodules["ln"]); + + int Offset = 0; + if (KvCache.has_value() && !KvCache->empty() && + KvCache->at(0).first.has_value() && + KvCache->at(0).first->first.shape(1) > 0) { + Offset = KvCache->at(0).first->first.shape(1); + } + std::vector Start(Parameters.at("positional_embedding").shape().size(), + 0); + std::vector End = Parameters.at("positional_embedding").shape(); + Start[0] = Offset; + End[0] = Offset + X.shape(-1); + + mx::array Result = TokenEmbedding->forward(X) + + slice(Parameters.at("positional_embedding"), Start, End); + + std::vector>, + std::optional>>> + NewKvCache; + std::vector> CrossQk; + + if (!KvCache.has_value()) { + NewKvCache.resize(Blocks.size()); + for (auto &Item : NewKvCache) { + Item = {std::nullopt, std::nullopt}; + } + } else { + NewKvCache = *KvCache; + } + + CrossQk.resize(Blocks.size()); + + for (size_t I = 0; I < Blocks.size(); ++I) { + auto [NewResult, UpdatedCache, BlockCrossQk] = + Blocks[I]->forward(Result, Xa, Mask, NewKvCache[I]); + Result = NewResult; + NewKvCache[I] = UpdatedCache; + CrossQk[I] = BlockCrossQk; + } + Result = Ln->forward(Result); + mx::array Logits = TokenEmbedding->asLinear(Result); + + return {Logits, NewKvCache, CrossQk}; +} + +Whisper::Whisper(const ModelDimensions &Dims, mx::Dtype Dtype) : Dims(Dims) { + auto Encoder = std::make_shared( + Dims.NMels, Dims.NAudioCtx, Dims.NAudioState, Dims.NAudioHead, + Dims.NAudioLayer, Dtype); + + auto Decoder = + std::make_shared(Dims.NVocab, Dims.NTextCtx, Dims.NTextState, + Dims.NTextHead, Dims.NTextLayer, Dtype); + + registerModule("encoder", Encoder); + registerModule("decoder", Decoder); + registerParameter("alignment_heads", mx::array({})); + // // Initialize alignment heads (use last half of decoder layers by default) + // std::vector> AllHeads( + // Dims.NTextLayer, std::vector(Dims.NTextHead, false)); + // for (int I = Dims.NTextLayer / 2; I < Dims.NTextLayer; ++I) { + // std::fill(AllHeads[I].begin(), AllHeads[I].end(), true); + // } + + // // Find all True positions and create the alignment heads array + // // Equivalent to: self.alignment_heads = + // // mx.array(np.asarray(all_heads.nonzero()).T) + // std::vector> NonzeroIndices; + // for (int Layer = 0; Layer < Dims.NTextLayer; ++Layer) { + // for (int Head = 0; Head < Dims.NTextHead; ++Head) { + // if (AllHeads[Layer][Head]) { + // NonzeroIndices.push_back({Layer, Head}); + // } + // } + // } + + // // Convert to mx::array format [N, 2] where N is number of True positions + // if (!NonzeroIndices.empty()) { + // std::vector FlatIndices; + // FlatIndices.reserve(NonzeroIndices.size() * 2); + // for (const auto &Index : NonzeroIndices) { + // FlatIndices.push_back(Index[0]); // Layer index + // FlatIndices.push_back(Index[1]); // Head index + // } + // AlignmentHeads = + // mx::array(FlatIndices.data(), + // {static_cast(NonzeroIndices.size()), 2}, mx::int64); + // } else { + // // Create empty array with correct shape + // AlignmentHeads = mx::zeros({0, 2}, mx::int64); + // } +} + +mx::array Whisper::forward(const mx::array &Mel, const mx::array &Tokens) { + auto Encoder = std::dynamic_pointer_cast(Submodules["encoder"]); + auto Decoder = std::dynamic_pointer_cast(Submodules["decoder"]); + + mx::array AudioFeatures = Encoder->forward(Mel); + auto [Logits, _, __] = Decoder->forward(Tokens, AudioFeatures); + return Logits; +} + +mx::array Whisper::embedAudio(const mx::array &Mel) { + auto Encoder = std::dynamic_pointer_cast(Submodules["encoder"]); + return Encoder->forward(Mel); +} + +mx::array Whisper::logits(const mx::array &Tokens, + const mx::array &AudioFeatures) { + auto Decoder = std::dynamic_pointer_cast(Submodules["decoder"]); + auto [Logits, _, __] = Decoder->forward(Tokens, AudioFeatures); + return Logits; +} + +std::pair>> +Whisper::forwardWithCrossQk(const mx::array &Mel, const mx::array &Tokens) { + auto Encoder = std::dynamic_pointer_cast(Submodules["encoder"]); + auto Decoder = std::dynamic_pointer_cast(Submodules["decoder"]); + + mx::array AudioFeatures = Encoder->forward(Mel); + auto [Logits, _, CrossQk] = Decoder->forward(Tokens, AudioFeatures); + return {Logits, CrossQk}; +} + +bool Whisper::isMultilingual() const { return Dims.NVocab >= 51865; } + +int Whisper::numLanguages() const { + return Dims.NVocab - 51765 - static_cast(isMultilingual()); +} + +std::shared_ptr Whisper::fromPretrained(const std::string &ModelPath) { + std::filesystem::path Path(ModelPath); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto Error = Parser.load((Path / "config.json").string()).get(Doc); + if (Error) { + spdlog::error("Could not open config.json at {}", Path.string()); + assumingUnreachable(); + } + auto Obj = Doc.get_object(); + ModelDimensions DefaultDims = ModelDimensions::fromDict(Obj.value()); + auto Model = std::make_shared(DefaultDims); + std::vector WeightFiles; + for (auto &P : std::filesystem::directory_iterator(Path)) { + if (P.path().extension() == ".safetensors") + WeightFiles.push_back(P.path()); + } + if (WeightFiles.empty()) { + spdlog::error("No safetensors found in {}.", Path.string()); + assumingUnreachable(); + } + std::unordered_map Weights; + for (auto &Wf : WeightFiles) { + auto W = mx::load_safetensors(Wf.string()); + Weights.insert(W.first.begin(), W.first.end()); + } + Model->update(Weights); + + return Model; +} + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/whisper.h b/plugins/wasi_nn/MLX/model/whisper/whisper.h new file mode 100644 index 00000000..2587a37f --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/whisper.h @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#pragma once + +#include "mlx/base.h" +#include "simdjson.h" +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +namespace nn = mlx::core::nn; + +namespace whisper { + +struct ModelDimensions { + int NMels = 80; + int NAudioCtx = 1500; + int NAudioState = 768; + int NAudioHead = 12; + int NAudioLayer = 12; + int NVocab = 51864; + int NTextCtx = 448; + int NTextState = 768; + int NTextHead = 12; + int NTextLayer = 12; + + static ModelDimensions fromDict(const simdjson::dom::object &Obj); +}; + +// Utility function for positional embeddings +mx::array sinusoids(int Length, int Channels, float MaxTimescale = 10000.0f); + +class MultiHeadAttention : public nn::Module { +public: + MultiHeadAttention(int NState, int NHead); + + std::tuple, mx::array> + forward(const mx::array &X, const std::optional &Xa = std::nullopt, + const std::optional &Mask = std::nullopt, + const std::optional> &KvCache = + std::nullopt); + +private: + std::pair + qkvAttention(const mx::array &Q, const mx::array &K, const mx::array &V, + const std::optional &Mask = std::nullopt); + + int NHead; +}; + +class ResidualAttentionBlock : public nn::Module { +public: + ResidualAttentionBlock(int NState, int NHead, bool CrossAttention = false); + + std::tuple>, + std::optional>>, + std::optional> + forward(const mx::array &X, const std::optional &Xa = std::nullopt, + const std::optional &Mask = std::nullopt, + const std::optional< + std::pair>, + std::optional>>> + &KvCache = std::nullopt); + +private: + bool HasCrossAttention; +}; + +class AudioEncoder : public nn::Module { +public: + AudioEncoder(int NMels, int NCtx, int NState, int NHead, int NLayer, + mx::Dtype Dtype = mx::float16); + + mx::array forward(const mx::array &X); + +private: + mx::array PositionalEmbedding = mx::array({}); + std::vector> Blocks; +}; + +class TextDecoder : public nn::Module { +public: + TextDecoder(int NVocab, int NCtx, int NState, int NHead, int NLayer, + mx::Dtype Dtype = mx::float16); + + std::tuple< + mx::array, + std::vector>, + std::optional>>>, + std::vector>> + forward(const mx::array &X, const mx::array &Xa, + const std::optional>, + std::optional>>>> + &KvCache = std::nullopt); + +private: + mx::array Mask = mx::array({}); + std::vector> Blocks; +}; + +class Whisper : public nn::Module { +public: + Whisper(const ModelDimensions &Dims, mx::Dtype Dtype = mx::float16); + + mx::array forward(const mx::array &Mel, const mx::array &Tokens); + + mx::array embedAudio(const mx::array &Mel); + mx::array logits(const mx::array &Tokens, const mx::array &AudioFeatures); + std::pair>> + forwardWithCrossQk(const mx::array &Mel, const mx::array &Tokens); + + bool isMultilingual() const; + int numLanguages() const; + + static std::shared_ptr fromPretrained(const std::string &ModelPath); + + ModelDimensions Dims; + mx::array AlignmentHeads = mx::array({}); +}; + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp b/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp new file mode 100644 index 00000000..546952f6 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp @@ -0,0 +1,819 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#include "whisper_transcribe.h" +#include "mlx/base.h" +#include "spdlog/spdlog.h" +#include "whisper/tokenizer.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +mx::array loadAudio(const std::string &FilePath, int SampleRate) { + int Channels = 1; + std::vector Cmd = {"ffmpeg", "-nostdin", + "-threads", "0", + "-i", FilePath, + "-f", "s16le", + "-ac", std::to_string(Channels), + "-acodec", "pcm_s16le", + "-ar", std::to_string(SampleRate), + "-v", "quiet", + "-"}; + + // Build command string + std::stringstream CmdStream; + for (size_t i = 0; i < Cmd.size(); ++i) { + if (i > 0) + CmdStream << " "; + CmdStream << Cmd[i]; + } + std::string CmdString = CmdStream.str(); + + // Execute ffmpeg command + FILE *Pipe = popen(CmdString.c_str(), "r"); + if (!Pipe) { + throw std::runtime_error("Failed to execute ffmpeg command"); + } + + // Read raw audio data + std::vector AudioData; + int16_t Buffer[4096]; + size_t BytesRead; + + while ((BytesRead = fread(Buffer, sizeof(int16_t), 4096, Pipe)) > 0) { + AudioData.insert(AudioData.end(), Buffer, Buffer + BytesRead); + } + + int ExitCode = pclose(Pipe); + if (ExitCode != 0) { + throw std::runtime_error("ffmpeg command failed with exit code: " + + std::to_string(ExitCode)); + } + + if (AudioData.empty()) { + throw std::runtime_error("No audio data loaded from file: " + FilePath); + } + + mx::array AudioArray = mx::array( + AudioData.data(), {static_cast(AudioData.size())}, mx::int16); + + if (Channels > 1) { + int Frames = AudioData.size() / Channels; + AudioArray = mx::reshape(AudioArray, {Frames, Channels}); + } + + mx::array FloatAudio = mx::astype(AudioArray, mx::float32) / 32768.0f; + + return FloatAudio; +} + +mx::array padOrTrim(const mx::array &Array, int Length, int Axis) { + auto Shape = Array.shape(); + int ActualAxis = Axis < 0 ? Shape.size() + Axis : Axis; + + if (Shape[ActualAxis] > Length) { + std::vector Start(Shape.size(), 0); + std::vector End = Shape; + std::vector Strides(Shape.size(), 1); + End[ActualAxis] = Length; + return mx::slice(Array, Start, End, Strides); + } + + if (Shape[ActualAxis] < Length) { + std::vector> PadWidths(Shape.size(), {0, 0}); + PadWidths[ActualAxis] = {0, Length - Shape[ActualAxis]}; + return mx::pad(Array, PadWidths); + } + + return Array; +} + +mx::array hanningWindow(int Size) { + mx::array N = mx::arange(Size + 1); + mx::array Window = 0.5f * (1.0f - mx::cos(2.0f * M_PI * N / Size)); + std::vector Start = {0}; + std::vector End = {-1}; + std::vector Strides = {1}; + return mx::slice(Window, Start, End, Strides); +} + +mx::array stft(const mx::array &X, const mx::array &Window, int NPerseg = 256, + int NOverlap = 0, int NfFt = 0, + const std::string &PadMode = "reflect") { + if (NfFt == 0) + NfFt = NPerseg; + + auto Pad = [](const mx::array &X, int Padding, const std::string &PadMode) { + if (PadMode == "constant") { + std::vector> PadWidths = {{Padding, Padding}}; + return mx::pad(X, PadWidths); + } + if (PadMode == "reflect") { + // x[1:padding+1][::-1] + std::vector Start(X.shape().size(), 0); + std::vector End = X.shape(); + std::vector Strides(X.shape().size(), 1); + + Start[0] = 1; + End[0] = Padding + 1; + auto Prefix = mx::slice(X, Start, End, Strides); + Start[0] = -(Padding + 1); + End[0] = -1; + auto Suffix = mx::slice(X, Start, End, Strides); + + return mx::concatenate({Prefix, X, Suffix}); + } + throw std::runtime_error("Invalid pad mode: " + PadMode); + }; + + // Pad the signal + int Padding = NPerseg / 2; + auto Padded = Pad(X, Padding, PadMode); + std::vector Strides = {static_cast(NOverlap), 1}; + int T = (Padded.shape(0) - NPerseg + NOverlap) / NOverlap; + auto Shape = std::vector{T, NfFt}; + auto StridedX = mx::as_strided(Padded, Shape, Strides, 0); + + return mx::fft::rfft(StridedX * Window); +} + +mx::array melFilters(int NMels) { + // Load precomputed mel filters from a file. + if (NMels != 80 && NMels != 128) { + spdlog::error("Unsupported number of mel filters: " + + std::to_string(NMels)); + assumingUnreachable(); + } + std::string FileName = "assets/mel_filters_" + std::to_string(NMels) + ".npy"; + return mx::load(FileName); +} + +mx::array logMelSpectrogram(const mx::array &Audio, int NMels, int Padding) { + auto PaddedAudio = Audio; + if (Padding > 0) { + std::vector> PadWidths = {{0, Padding}}; + PaddedAudio = mx::pad(Audio, PadWidths); + } + + auto Window = hanningWindow(DefaultNFft); + auto Freqs = stft(PaddedAudio, Window, DefaultNFft, DefaultHopLength, + DefaultNFft, "reflect"); + // freqs[:-1, :].abs().square() + std::vector Start = {0, 0}; + std::vector End = {static_cast(Freqs.shape(0) - 1), + static_cast(Freqs.shape(1))}; + std::vector Strides = {1, 1}; + auto FreqsSliced = mx::slice(Freqs, Start, End, Strides); + auto Magnitudes = mx::square(mx::abs(FreqsSliced)); + // Apply mel filters + auto Filters = melFilters(NMels); + auto MelSpec = mx::matmul(Magnitudes, mx::transpose(Filters)); + auto LogSpec = + mx::log10(mx::maximum(MelSpec, mx::full(MelSpec.shape(), 1e-10f))); + LogSpec = mx::maximum(LogSpec, mx::max(LogSpec) - 8.0f); + LogSpec = (LogSpec + 4.0f) / 4.0f; + return LogSpec; +} + +// Utility functions +std::string formatTimestamp(float Seconds) { + assert(Seconds >= 0); + int Milliseconds = static_cast(std::round(Seconds * 1000.0f)); + + int Hours = Milliseconds / 3600000; + Milliseconds -= Hours * 3600000; + + int Minutes = Milliseconds / 60000; + Milliseconds -= Minutes * 60000; + + int Secs = Milliseconds / 1000; + Milliseconds -= Secs * 1000; + + std::stringstream Ss; + if (Hours > 0) { + Ss << std::setfill('0') << std::setw(2) << Hours << ":"; + } + Ss << std::setfill('0') << std::setw(2) << Minutes << ":" << std::setfill('0') + << std::setw(2) << Secs << "." << std::setfill('0') << std::setw(3) + << Milliseconds; + + return Ss.str(); +} + +std::optional getEnd(const std::vector &Segments) { + for (auto It = Segments.rbegin(); It != Segments.rend(); ++It) { + for (auto WordIt = It->Words.rbegin(); WordIt != It->Words.rend(); + ++WordIt) { + return WordIt->End; + } + if (!It->Words.empty()) { + return It->End; + } + } + return Segments.empty() ? std::nullopt + : std::make_optional(Segments.back().End); +} + +// Decoding functions +DecodingResult +decodeWithFallback(std::shared_ptr Model, + const mx::array &MelSegment, + const DecodingOptions &DecodeOptions, + const std::vector &Temperatures, + std::unique_ptr &Tokenizer, + std::optional CompressionRatioThreshold, + std::optional LogprobThreshold, + std::optional NoSpeechThreshold) { + + DecodingResult Result; + Result.AudioFeatures = MelSegment; + Result.Language = ""; + Result.Tokens = std::vector(); + Result.Text = ""; + + for (float Temp : Temperatures) { + // Simple greedy decoding for temperature = 0, sampling otherwise + DecodingOptions Options = DecodeOptions; + Options.Temperature = Temp; + + // This is a simplified decode implementation + // In practice, you'd implement the full beam search/sampling logic + std::vector Tokens; + + // Start with SOT sequence + auto SotSequence = Tokenizer->SotSequence; + Tokens.insert(Tokens.end(), SotSequence.begin(), SotSequence.end()); + + // Generate tokens + mx::array CurrentTokens = mx::array( + Tokens.data(), {1, static_cast(Tokens.size())}, mx::int32); + Result = std::get(decode(Model, MelSegment, Options)); + + // Check fallback conditions + bool NeedsFallback = false; + if (CompressionRatioThreshold && + Result.CompressionRatio > *CompressionRatioThreshold) { + NeedsFallback = true; + } + if (LogprobThreshold && Result.AvgLogprob < *LogprobThreshold) { + NeedsFallback = true; + } + if (NoSpeechThreshold && Result.NoSpeechProb > *NoSpeechThreshold) { + NeedsFallback = false; + } + + if (!NeedsFallback) { + break; + } + } + + return Result; +} + +// Word-level timestamp functions. +void addWordTimestamps(std::vector &Segments, + std::shared_ptr Model, + std::unique_ptr &Tokenizer, + const mx::array &Mel, int NumFrames, + const std::string &PrependPunctuations, + const std::string &AppendPunctuations, + float LastSpeechTimestamp) { + + // This is a simplified implementation. + // A full implementation would use cross-attention patterns and DTW. + for (auto &Segment : Segments) { + if (!Segment.Tokens.empty()) { + float Duration = Segment.End - Segment.Start; + float TimePerToken = Duration / Segment.Tokens.size(); + + // Simple uniform distribution of word timestamps + std::string Text = Segment.Text; + std::istringstream Iss(Text); + std::string Word; + int WordIdx = 0; + + while (Iss >> Word && + static_cast(WordIdx) < Segment.Tokens.size()) { + WordInfo WordInfo; + WordInfo.Word = Word; + WordInfo.Start = Segment.Start + WordIdx * TimePerToken; + WordInfo.End = Segment.Start + (WordIdx + 1) * TimePerToken; + WordInfo.Probability = 0.8f; // Default probability + + Segment.Words.push_back(WordInfo); + WordIdx++; + } + } + } +} + +// Anomaly detection functions +float wordAnomalyScore(const WordInfo &Word) { + float Score = 0.0f; + + if (Word.Probability < 0.15f) { + Score += 1.0f; + } + + float Duration = Word.End - Word.Start; + if (Duration < 0.133f) { + Score += (0.133f - Duration) * 15.0f; + } + if (Duration > 2.0f) { + Score += Duration - 2.0f; + } + + return Score; +} + +bool isSegmentAnomaly(const std::optional &Segment) { + if (!Segment || Segment->Words.empty()) { + return false; + } + + // Filter out punctuation words + std::vector FilteredWords; + std::string Punctuation = "\"'?([{-\"'.,!?:\")]},"; + + for (const auto &Word : Segment->Words) { + if (Punctuation.find(Word.Word) == std::string::npos) { + FilteredWords.push_back(Word); + } + } + + if (FilteredWords.size() > 8) { + FilteredWords.resize(8); + } + + float TotalScore = 0.0f; + for (const auto &Word : FilteredWords) { + TotalScore += wordAnomalyScore(Word); + } + + return TotalScore >= 3.0f || TotalScore + 0.01f >= FilteredWords.size(); +} + +std::optional +nextWordsSegment(const std::vector &Segments) { + for (const auto &Segment : Segments) { + if (!Segment.Words.empty()) { + return Segment; + } + } + return std::nullopt; +} + +// Main transcribe function +TranscribeResult +transcribe(const std::variant &Audio, + std::shared_ptr Model, std::optional Verbose, + std::variant> Temperature, + std::optional CompressionRatioThreshold, + std::optional LogprobThreshold, + std::optional NoSpeechThreshold, bool ConditionOnPreviousText, + std::optional InitialPrompt, bool WordTimestamps, + const std::string &PrependPunctuations, + const std::string &AppendPunctuations, + std::variant> ClipTimestamps, + std::optional HallucinationSilenceThreshold, + const DecodingOptions &DecodeOptions) { + + // Get audio array + mx::array AudioArray = + std::holds_alternative(Audio) + ? loadAudio(std::get(Audio), DefaultSampleRate) + : std::get(Audio); + + // Get dtype + mx::Dtype Dtype = DecodeOptions.Fp16 ? mx::float16 : mx::float32; + + // Generate mel spectrogram + auto Mel = logMelSpectrogram(AudioArray, Model->Dims.NMels, DefaultNSamples); + int ContentFrames = Mel.shape(-2) - DefaultNFrames; + float ContentDuration = + static_cast(ContentFrames * DefaultHopLength) / DefaultSampleRate; + + // Language detection + DecodingOptions Options = DecodeOptions; + if (!Options.Language) { + if (!Model->isMultilingual()) { + Options.Language = "en"; + } else { + if (Verbose.value_or(false)) { + std::cout << "Detecting language using up to the first 30 seconds.\n"; + } + + auto MelSegment = padOrTrim(Mel, DefaultNFrames, -2); + MelSegment = mx::astype(MelSegment, Dtype); + auto [LangTokens, LangProbs] = detectLanguage(Model, MelSegment); + + // Find most probable language from the detection results + if (!LangProbs.empty() && !LangProbs[0].empty()) { + // Find language with highest probability + auto MaxIterator = std::max_element( + LangProbs[0].begin(), LangProbs[0].end(), + [](const auto &A, const auto &B) { return A.second < B.second; }); + Options.Language = MaxIterator->first; + } else { + Options.Language = "en"; // Default fallback only if detection failed + } + + if (Verbose.value_or(false)) { + std::string LangName = findLanguageByCode(*Options.Language); + if (LangName.empty()) { + LangName = *Options.Language; // fallback to code if not found + } + std::cout << "Detected language: " << LangName << std::endl; + } + } + } + + auto Task = Options.Task; + auto Tokenizer = getTokenizer(Model->isMultilingual(), Model->numLanguages(), + *Options.Language, Task); + + // Parse clip timestamps + std::vector ClipTimes; + if (std::holds_alternative(ClipTimestamps)) { + std::string ClipStr = std::get(ClipTimestamps); + if (!ClipStr.empty()) { + std::stringstream Ss(ClipStr); + std::string Item; + while (std::getline(Ss, Item, ',')) { + ClipTimes.push_back(std::stof(Item)); + } + } + } else { + ClipTimes = std::get>(ClipTimestamps); + } + + // Set up seek points + std::vector SeekPoints; + for (float Ts : ClipTimes) { + SeekPoints.push_back( + static_cast(std::round(Ts * DefaultFramesPerSecond))); + } + if (SeekPoints.empty()) { + SeekPoints.push_back(0); + } + if (SeekPoints.size() % 2 == 1) { + SeekPoints.push_back(ContentFrames); + } else { + SeekPoints.back() = std::min(ContentFrames, SeekPoints.back()); + } + + // Create seek clips + std::vector> SeekClips; + for (size_t I = 0; I < SeekPoints.size(); I += 2) { + SeekClips.emplace_back(SeekPoints[I], SeekPoints[I + 1]); + } + + int Seek = SeekClips[0].first; + // time_precision + + std::vector AllTokens; + std::vector AllSegments; + // prompt_reset_since + int PromptResetSince = 0; + + if (InitialPrompt) { + auto PromptTokens = Tokenizer->encode(" " + *InitialPrompt); + AllTokens.insert(AllTokens.end(), PromptTokens.begin(), PromptTokens.end()); + } + + std::vector Temperatures; + if (std::holds_alternative(Temperature)) { + Temperatures = {std::get(Temperature)}; + } else { + Temperatures = std::get>(Temperature); + } + + const int InputStride = DefaultNFrames / Model->Dims.NAudioCtx; // 2 for tiny + const float TimePrecision = + static_cast(InputStride * DefaultHopLength) / DefaultSampleRate; + const int FramesPerSecond = DefaultFramesPerSecond; + + // Processing loop + float LastSpeechTimestamp = 0.0f; + for (const auto &[SeekClipStart, SeekClipEnd] : SeekClips) { + while (Seek < SeekClipEnd) { + float TimeOffset = + static_cast(Seek * DefaultHopLength) / DefaultSampleRate; + float WindowEndTime = + static_cast((Seek + DefaultNFrames) * DefaultHopLength) / + DefaultSampleRate; + int SegmentSize = + std::min({DefaultNFrames, ContentFrames - Seek, SeekClipEnd - Seek}); + + // Extract mel segment + std::vector Start(Mel.shape().size(), 0); + std::vector End = Mel.shape(); + Start[0] = Seek; + End[0] = Seek + SegmentSize; + auto MelSegment = mx::slice(Mel, Start, End); + MelSegment = padOrTrim(MelSegment, DefaultNFrames, -2); + MelSegment = mx::astype(MelSegment, Dtype); + + // Decode segment + // Provide prompt tokens since last reset + { + std::vector PromptSlice; + if (PromptResetSince >= 0 && + PromptResetSince <= static_cast(AllTokens.size())) { + PromptSlice.assign(AllTokens.begin() + PromptResetSince, + AllTokens.end()); + } + Options.Prompt = PromptSlice; // empty slice allowed (adds sot_prev) + } + + auto Result = decodeWithFallback(Model, MelSegment, Options, Temperatures, + Tokenizer, CompressionRatioThreshold, + LogprobThreshold, NoSpeechThreshold); + + // Voice activity check and fast-forward if silence + if (NoSpeechThreshold) { + bool ShouldSkip = Result.NoSpeechProb > *NoSpeechThreshold; + if (LogprobThreshold && Result.AvgLogprob > *LogprobThreshold) { + ShouldSkip = false; + } + if (ShouldSkip) { + Seek += SegmentSize; + continue; + } + } + + // Prepare tokens for timestamp analysis + const auto &TokVec = Result.Tokens; + std::vector TimestampMask(TokVec.size(), false); + int TsBegin = Tokenizer->getTimestampBegin(); + for (size_t I = 0; I < TokVec.size(); ++I) { + TimestampMask[I] = TokVec[I] >= TsBegin; + } + + bool SingleTimestampEnding = false; + if (TimestampMask.size() >= 2) { + SingleTimestampEnding = + (!TimestampMask[TimestampMask.size() - 2] && TimestampMask.back()); + } + + // Find consecutive timestamp pairs + std::vector ConsecutiveIdx; + for (int I = 1; I < static_cast(TimestampMask.size()); ++I) { + if (TimestampMask[I - 1] && TimestampMask[I]) { + ConsecutiveIdx.push_back(I); + } + } + + std::vector CurrentSegments; + auto NewSegmentFromSlice = [&](int LastSliceIdx, int CurrentSliceIdx, + float StartBase, float EndBase) { + std::vector SliceTokens; + SliceTokens.insert(SliceTokens.end(), TokVec.begin() + LastSliceIdx, + TokVec.begin() + CurrentSliceIdx); + if (SliceTokens.empty()) + return; // nothing to add + int StartTsPos = SliceTokens.front() - TsBegin; + int EndTsPos = SliceTokens.back() - TsBegin; + TranscribeSegment S; + S.Id = static_cast(AllSegments.size() + CurrentSegments.size()); + S.Seek = Seek; + S.Start = StartBase + StartTsPos * TimePrecision; + S.End = StartBase + EndTsPos * TimePrecision; + // text only from tokens < eot + std::vector TextTokens; + int Eot = Tokenizer->getEot(); + for (int Tok : SliceTokens) { + if (Tok < Eot) + TextTokens.push_back(Tok); + } + S.Text = Tokenizer->decode(TextTokens); + S.Tokens.assign(SliceTokens.begin(), SliceTokens.end()); + S.Temperature = Result.Temperature; + S.AvgLogprob = Result.AvgLogprob; + S.CompressionRatio = Result.CompressionRatio; + S.NoSpeechProb = Result.NoSpeechProb; + CurrentSegments.push_back(std::move(S)); + }; + + if (!ConsecutiveIdx.empty()) { + std::vector Slices = ConsecutiveIdx; + if (SingleTimestampEnding) { + Slices.push_back(static_cast(TokVec.size())); + } + + int LastSlice = 0; + for (int Cur : Slices) { + NewSegmentFromSlice(LastSlice, Cur, TimeOffset, TimeOffset); + LastSlice = Cur; + } + + if (SingleTimestampEnding) { + // no speech after the last timestamp + Seek += SegmentSize; + } else { + int LastTimestampPos = TokVec[Slices.back() - 1] - TsBegin; + Seek += LastTimestampPos * InputStride; + } + } else { + // No consecutive timestamp tokens + float SegmentDuration = + static_cast(SegmentSize * DefaultHopLength) / + DefaultSampleRate; + int LastTsIndex = -1; + for (int I = static_cast(TokVec.size()) - 1; I >= 0; --I) { + if (TimestampMask[I]) { + LastTsIndex = I; + break; + } + } + if (LastTsIndex != -1 && TokVec[LastTsIndex] != TsBegin) { + int LastTsPos = TokVec[LastTsIndex] - TsBegin; + SegmentDuration = LastTsPos * TimePrecision; + } + + // Create one segment for the whole window + TranscribeSegment S; + S.Id = static_cast(AllSegments.size()); + S.Seek = Seek; + S.Start = TimeOffset; + S.End = TimeOffset + SegmentDuration; + // text from tokens < eot + std::vector TextTokens; + int Eot = Tokenizer->getEot(); + for (int Tok : TokVec) { + if (Tok < Eot) + TextTokens.push_back(Tok); + } + S.Text = Tokenizer->decode(TextTokens); + S.Tokens = TokVec; + S.Temperature = Result.Temperature; + S.AvgLogprob = Result.AvgLogprob; + S.CompressionRatio = Result.CompressionRatio; + S.NoSpeechProb = Result.NoSpeechProb; + CurrentSegments.push_back(std::move(S)); + + Seek += SegmentSize; + } + + // Word-level timestamps and hallucination handling + if (WordTimestamps) { + addWordTimestamps(CurrentSegments, Model, Tokenizer, MelSegment, + SegmentSize, PrependPunctuations, AppendPunctuations, + LastSpeechTimestamp); + + if (!SingleTimestampEnding) { + auto LastWordEndOpt = getEnd(CurrentSegments); + if (LastWordEndOpt && *LastWordEndOpt > TimeOffset) { + Seek = static_cast( + std::round((*LastWordEndOpt) * FramesPerSecond)); + } + } + + if (HallucinationSilenceThreshold) { + float Threshold = *HallucinationSilenceThreshold; + if (!SingleTimestampEnding) { + auto LastWordEndOpt = getEnd(CurrentSegments); + if (LastWordEndOpt && *LastWordEndOpt > TimeOffset) { + float Remaining = WindowEndTime - *LastWordEndOpt; + if (Remaining > Threshold) { + Seek = static_cast( + std::round((*LastWordEndOpt) * FramesPerSecond)); + } else { + // keep default seek progression + } + } + } + + // if first segment might be a hallucination, skip leading silence + auto FirstWithWords = nextWordsSegment(CurrentSegments); + if (FirstWithWords && isSegmentAnomaly(FirstWithWords)) { + float Gap = FirstWithWords->Start - TimeOffset; + if (Gap > Threshold) { + int NewSeek = static_cast(std::round( + static_cast(Seek) + Gap * FramesPerSecond)); + Seek = NewSeek; + // restart loop + continue; + } + } + + // skip silence before any possible hallucination surrounded by + // silence + float HalLastEnd = LastSpeechTimestamp; + for (size_t Si = 0; Si < CurrentSegments.size(); ++Si) { + auto &Seg = CurrentSegments[Si]; + if (Seg.Words.empty()) + continue; + std::optional SegOpt = Seg; + if (isSegmentAnomaly(SegOpt)) { + std::optional NextSeg = std::nullopt; + for (size_t J = Si + 1; J < CurrentSegments.size(); ++J) { + if (!CurrentSegments[J].Words.empty()) { + NextSeg = CurrentSegments[J]; + break; + } + } + float HalNextStart = + NextSeg ? NextSeg->Words.front().Start + : TimeOffset + static_cast(SegmentSize * + DefaultHopLength) / + DefaultSampleRate; + bool SilenceBefore = (Seg.Start - HalLastEnd > Threshold) || + (Seg.Start < Threshold) || + (Seg.Start - TimeOffset < 2.0f); + bool SilenceAfter = (HalNextStart - Seg.End > Threshold) || + isSegmentAnomaly(NextSeg) || + (WindowEndTime - Seg.End < 2.0f); + if (SilenceBefore && SilenceAfter) { + Seek = static_cast(std::round( + std::max(TimeOffset + 1.0f, Seg.Start) * FramesPerSecond)); + float RemainingContent = ContentDuration - Seg.End; + if (RemainingContent < Threshold) { + Seek = ContentFrames; + } + // drop subsequent segments + CurrentSegments.resize(Si); + break; + } + } + HalLastEnd = Seg.End; + } + } + + // update last_speech_timestamp + auto LastWordEndOpt = getEnd(CurrentSegments); + if (LastWordEndOpt) { + LastSpeechTimestamp = *LastWordEndOpt; + } + } + + // Verbose output for each segment + if (Verbose.value_or(false)) { + for (const auto &S : CurrentSegments) { + std::cout << "[" << formatTimestamp(S.Start) << " --> " + << formatTimestamp(S.End) << "] " << S.Text << std::endl; + } + } + + // Clean instantaneous or empty segments + for (auto &S : CurrentSegments) { + auto Trimmed = S.Text; + // trim whitespace + Trimmed.erase(0, Trimmed.find_first_not_of(" \t\n\r\f\v")); + if (!Trimmed.empty()) + Trimmed.erase(Trimmed.find_last_not_of(" \t\n\r\f\v") + 1); + if (S.Start == S.End || Trimmed.empty()) { + S.Text.clear(); + S.Tokens.clear(); + S.Words.clear(); + } + } + + // Add to results and accumulate tokens + for (auto &S : CurrentSegments) { + S.Id = static_cast(AllSegments.size()); + AllSegments.push_back(S); + AllTokens.insert(AllTokens.end(), S.Tokens.begin(), S.Tokens.end()); + } + + // Prompt reset logic + if (!ConditionOnPreviousText || Result.Temperature > 0.5f) { + PromptResetSince = static_cast(AllTokens.size()); + } + } + } + + // Prepare final result + TranscribeResult FinalResult; + FinalResult.Segments = AllSegments; + FinalResult.Language = *Options.Language; + + // Concatenate all text + std::stringstream TextStream; + for (const auto &Segment : AllSegments) { + TextStream << Segment.Text; + } + FinalResult.Text = TextStream.str(); + + return FinalResult; +} + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper_transcribe.h b/plugins/wasi_nn/MLX/model/whisper_transcribe.h new file mode 100644 index 00000000..9948f606 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper_transcribe.h @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#pragma once + +#include "whisper/decoding.h" +#include "whisper/tokenizer.h" +#include "whisper/whisper.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +// Audio processing constants +constexpr int DefaultSampleRate = 16000; +constexpr int DefaultNFft = 400; +constexpr int DefaultHopLength = 160; +constexpr int DefaultChunkLength = 30; +constexpr int DefaultNSamples = + DefaultChunkLength * DefaultSampleRate; // 480000 samples +constexpr int DefaultNFrames = + DefaultNSamples / DefaultHopLength; // 3000 frames +constexpr int DefaultFramesPerSecond = DefaultSampleRate / DefaultHopLength; +constexpr int DefaultNSamplesPerToken = DefaultHopLength * 2; + +extern const std::vector> LANGUAGES; + +// Word information for word-level timestamps +struct WordInfo { + float Start; + float End; + std::string Word; + float Probability; +}; + +// Segment information +struct TranscribeSegment { + int Id; + int Seek; + float Start; + float End; + std::string Text; + std::vector Tokens; + float Temperature; + float AvgLogprob; + float CompressionRatio; + float NoSpeechProb; + std::vector Words; +}; + +// Complete transcribe result +struct TranscribeResult { + std::string Text; + std::vector Segments; + std::string Language; +}; + +// Audio processing functions +mx::array loadAudio(const std::string &FilePath, + int SampleRate = DefaultSampleRate); +mx::array padOrTrim(const mx::array &Array, int Length = DefaultNSamples, + int Axis = -1); +mx::array logMelSpectrogram(const mx::array &Audio, int NMels = 80, + int Padding = 0); + +// Utility functions +std::string formatTimestamp(float Seconds); +std::optional getEnd(const std::vector &Segments); + +// Decoding functions +DecodingResult +decodeWithFallback(std::shared_ptr Model, + const mx::array &MelSegment, + const DecodingOptions &DecodeOptions, + const std::vector &Temperatures, + std::unique_ptr &Tokenizer, + std::optional CompressionRatioThreshold, + std::optional LogprobThreshold, + std::optional NoSpeechThreshold); + +// Word-level timestamp functions +void addWordTimestamps(std::vector &Segments, + std::shared_ptr Model, + std::unique_ptr &Tokenizer, + const mx::array &Mel, int NumFrames, + const std::string &PrependPunctuations, + const std::string &AppendPunctuations, + float LastSpeechTimestamp = 0.0f); + +// Anomaly detection functions +float wordAnomalyScore(const WordInfo &Word); +bool isSegmentAnomaly(const std::optional &Segment); +std::optional +nextWordsSegment(const std::vector &Segments); + +// Main transcribe function +TranscribeResult +transcribe(const std::variant &Audio, + std::shared_ptr Model, + std::optional Verbose = std::nullopt, + std::variant> Temperature = + std::vector{0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f}, + std::optional CompressionRatioThreshold = 2.4f, + std::optional LogprobThreshold = -1.0f, + std::optional NoSpeechThreshold = 0.6f, + bool ConditionOnPreviousText = true, + std::optional InitialPrompt = std::nullopt, + bool WordTimestamps = false, + const std::string &PrependPunctuations = + "\"'“¿([{-\"'.。,,!!??::”)]}、", + const std::string &AppendPunctuations = "\"'.,!?:\")]},", + std::variant> ClipTimestamps = "0", + std::optional HallucinationSilenceThreshold = std::nullopt, + const DecodingOptions &DecodeOptions = DecodingOptions()); + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/prompt/prompt.cpp b/plugins/wasi_nn/MLX/prompt/prompt.cpp new file mode 100644 index 00000000..7a944c29 --- /dev/null +++ b/plugins/wasi_nn/MLX/prompt/prompt.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "prompt/prompt.h" + +#include + +namespace WasmEdge::Host::WASINN::MLX { + +std::string TinyLLaMAPrompt::prepare(std::string Prompt) { + return SystemStart + TextEnd + Prompt + TextEnd + Assistant; +} + +std::string LLaMA2Prompt::prepare(std::string Prompt) { + return InstStart + SysStart + SysEnd + Prompt + TextEnd; +} + +std::string LLaMA3Prompt::prepare(std::string Prompt) { + return PropmtStart + StartHeader + "system" + EndHeader + TextEnd + Prompt + + EndHeader + TextEnd + StartHeader + "user" + EndHeader + Prompt + + TextEnd + StartHeader + "assistant" + EndHeader; +} + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/prompt/prompt.h b/plugins/wasi_nn/MLX/prompt/prompt.h new file mode 100644 index 00000000..d66cfc5b --- /dev/null +++ b/plugins/wasi_nn/MLX/prompt/prompt.h @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +class BasePrompt { +public: + std::string TextEnd; + + virtual std::string prepare(std::string Prompt) { return Prompt + TextEnd; }; +}; + +class TinyLLaMAPrompt : public BasePrompt { +public: + std::string SystemStart; + std::string User; + std::string Assistant; + + TinyLLaMAPrompt() { + SystemStart = "<|system|>"; + Assistant = "<|assistant|>"; + User = "<|user|>"; + TextEnd = ""; + } + + std::string prepare(std::string Prompt) override; +}; + +class LLaMA2Prompt : public BasePrompt { +public: + std::string SysStart; + std::string SysEnd; + std::string InstStart; + + LLaMA2Prompt() { + SysStart = "<>"; + SysEnd = "<>"; + InstStart = "[INST]"; + TextEnd = "[/INST]"; + } + + std::string prepare(std::string Prompt) override; +}; + +class LLaMA3Prompt : public BasePrompt { +public: + std::string PropmtStart; + std::string StartHeader; + std::string EndHeader; + + LLaMA3Prompt() { + PropmtStart = "<|begin_of_text|>"; + StartHeader = "<|start_header_id|>"; + EndHeader = "<|end_header_id|>"; + TextEnd = "<|eot_id|>"; + } + + std::string prepare(std::string Prompt) override; +}; + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/bitnet.patch b/plugins/wasi_nn/bitnet.patch new file mode 100644 index 00000000..b9e506d5 --- /dev/null +++ b/plugins/wasi_nn/bitnet.patch @@ -0,0 +1,39 @@ +diff --git a/utils/codegen_tl1.py b/utils/codegen_tl1.py +index 4c2e7dd..6600a4d 100644 +--- a/utils/codegen_tl1.py ++++ b/utils/codegen_tl1.py +@@ -14,7 +14,9 @@ static void * aligned_malloc(size_t size) {{\n\ + return _aligned_malloc(size, 64);\n\ + #else\n\ + void * ptr = nullptr;\n\ +- posix_memalign(&ptr, 64, size);\n\ ++ if (posix_memalign(&ptr, 64, size) != 0) {{\n\ ++ return nullptr;\n\ ++ }}\n\ + return ptr;\n\ + #endif\n\ + }}\n\ +@@ -201,10 +203,10 @@ def gen_body_core_code(bm, by): + int8x16_t vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\ + int8x16x2_t vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\ + int8x16x2_t vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\ +- vec_c[{6}] += vec_v_left_{0}.val[0];\n\ +- vec_c[{6}] += vec_v_right_{0}.val[0];\n\ +- vec_c[{7}] += vec_v_left_{0}.val[1];\n\ +- vec_c[{7}] += vec_v_right_{0}.val[1];\n\ ++ vec_c[{6}] = vaddq_s16(vec_c[{6}], vmovl_s8(vget_low_s8(vec_v_left_{0}.val[0])));\n\ ++ vec_c[{6}] = vaddq_s16(vec_c[{6}], vmovl_s8(vget_low_s8(vec_v_right_{0}.val[0])));\n\ ++ vec_c[{7}] = vaddq_s16(vec_c[{7}], vmovl_s8(vget_low_s8(vec_v_left_{0}.val[1])));\n\ ++ vec_c[{7}] = vaddq_s16(vec_c[{7}], vmovl_s8(vget_low_s8(vec_v_right_{0}.val[1])));\n\ + ".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1) + + all_code = "".join([all_code, core_code]) +@@ -232,7 +234,7 @@ inline void tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\ + #ifdef __ARM_NEON\n\ + const int KK = BBK{0} / 2;\n\ + const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\ +- const int8x16_t vec_zero = vdupq_n_s16(0x0000);\n\ ++ const int16x8_t vec_zero = vdupq_n_s16(0x0000);\n\ + int8x16_t vec_lut[2 * KK];\n\ + ".format(pre, BM, BK) + diff --git a/plugins/wasi_nn/wasinn_bitnet.cpp b/plugins/wasi_nn/wasinn_bitnet.cpp new file mode 100644 index 00000000..15564a70 --- /dev/null +++ b/plugins/wasi_nn/wasinn_bitnet.cpp @@ -0,0 +1,2438 @@ +#include "wasinn_bitnet.h" +#include "wasinnenv.h" +#include + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET +#include "simdjson.h" +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::BitNet { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET + +namespace { + +// Macro for logging debug message. +#define LOG_DEBUG(Debug, ...) \ + if (Debug) { \ + spdlog::info("[WASI-NN][Debug] BitNet backend: "sv __VA_ARGS__); \ + } + +// Macro for logging info message. +#define LOG_INFO(Info, ...) \ + if (Info) { \ + spdlog::info("[WASI-NN] BitNet backend: "sv __VA_ARGS__); \ + } + +// Macro for logging warning message. +#define LOG_WARN(...) spdlog::warn("[WASI-NN] BitNet backend: "sv __VA_ARGS__); + +// Macro for logging error message. +#define LOG_ERROR(...) \ + spdlog::error("[WASI-NN] BitNet backend: "sv __VA_ARGS__); + +// Macro for logging an error message and returning. +#define RET_ERROR(Error, ...) \ + spdlog::error("[WASI-NN] BitNet backend: "sv __VA_ARGS__); \ + return Error; + +// Llama logging callback. +void llamaLogCallback(ggml_log_level LogLevel, const char *LogText, + void *UserData) { + Graph &GraphRef = *reinterpret_cast(UserData); + if (!GraphRef.EnableLog) { + return; + } + std::string Text(LogText); + // Remove the trailing newlines. + Text = Text.erase(Text.find_last_not_of("\n") + 1); + // Skip for "." + if (Text == ".") { + return; + } + if (LogLevel == GGML_LOG_LEVEL_ERROR) { + spdlog::error("[WASI-NN] BitNet.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_WARN) { + spdlog::warn("[WASI-NN] BitNet.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_INFO) { + spdlog::info("[WASI-NN] BitNet.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_DEBUG) { + spdlog::debug("[WASI-NN] BitNet.cpp: {}"sv, Text); + } +} + +// >>>>>>>> Metadata related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// Helper function to parse comma-separated string into vector. +void stringToList(const std::string &Raw, std::vector &Out) { + std::string Copy = Raw; + std::replace(Copy.begin(), Copy.end(), ',', ' '); + std::stringstream SS(Copy); + Out.clear(); + while (SS.good()) { + int TmpInt; + SS >> TmpInt; + Out.push_back(TmpInt); + } +} + +// Parse metadata from JSON. +ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, + const std::string &Metadata, bool *IsModelUpdated = nullptr, + bool *IsContextUpdated = nullptr, + bool *IsSamplerUpdated = nullptr) noexcept { + // Parse metadata from the json. + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + RET_ERROR(ErrNo::InvalidEncoding, "parse metadata error."sv) + } + + // Get the current llama parameters. + int64_t PrevNGPULayers = GraphRef.Params.n_gpu_layers; + int64_t PrevMainGpu = GraphRef.Params.main_gpu; + int64_t PrevThreads = GraphRef.Params.cpuparams.n_threads; + bool PrevFlashAttn = GraphRef.Params.flash_attn; + int64_t PrevCtxSize = GraphRef.Params.n_ctx; + bool PrevEmbedding = GraphRef.Params.embedding; + // Get the current sampler parameters. + double PrevTemp = GraphRef.Params.sparams.temp; + double PrevTopP = GraphRef.Params.sparams.top_p; + double PrevRepeatPenalty = GraphRef.Params.sparams.penalty_repeat; + double PrevPresencePenalty = GraphRef.Params.sparams.penalty_present; + double PrevFrequencyPenalty = GraphRef.Params.sparams.penalty_freq; + std::string PrevGrammar = GraphRef.Params.sparams.grammar; + uint64_t PrevSeed = GraphRef.Params.sparams.seed; + + // The plugin parameters. + if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-log"].get().get(GraphRef.EnableLog); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-log option."sv) + } + } + if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-debug-log"].get().get(GraphRef.EnableDebugLog); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-debug-log option."sv) + } + } + + // The model parameters. + if (Doc.at_key("main-gpu").error() == simdjson::SUCCESS) { + int64_t MainGPU; + auto Err = Doc["main-gpu"].get().get(MainGPU); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the main-gpu option."sv) + } + GraphRef.Params.main_gpu = static_cast(MainGPU); + } + if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { + int64_t NGPULayers; + auto Err = Doc["n-gpu-layers"].get().get(NGPULayers); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-gpu-layers option."sv) + } + GraphRef.Params.n_gpu_layers = static_cast(NGPULayers); + } + if (Doc.at_key("tensor-split").error() == simdjson::SUCCESS) { + // The TensorSplit is a comma-separated list of non-negative values. + // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. + + // helper function `stringToList` cannot be used here since tensor-split + // needs a fixed-size array with validation checks. + std::string_view TSV; + auto Err = Doc["tensor-split"].get().get(TSV); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the tensor-split option."sv) + } + std::string TS(TSV); + std::replace(TS.begin(), TS.end(), ',', ' '); + std::stringstream SS(TS); + std::memset(GraphRef.Params.tensor_split, 0, + sizeof(GraphRef.Params.tensor_split)); + uint32_t TensorSplitSize = 0; + while (SS.good()) { + float TmpTensor; + SS >> TmpTensor; + GraphRef.Params.tensor_split[TensorSplitSize++] = TmpTensor; + } + size_t NDevices = llama_max_devices(); + if (TensorSplitSize > NDevices) { + RET_ERROR( + ErrNo::InvalidArgument, + "Number of Tensor-Split is larger than MaxDevices, please reduce "sv + "the size of tensor-split."sv) + } + for (size_t Idx = TensorSplitSize; Idx < NDevices; Idx++) { + GraphRef.Params.tensor_split[TensorSplitSize++] = 0.0f; + } + } + if (Doc.at_key("embedding").error() == simdjson::SUCCESS) { + auto Err = Doc["embedding"].get().get(GraphRef.Params.embedding); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embedding option."sv) + } + } + if (Doc.at_key("split-mode").error() == simdjson::SUCCESS) { + std::string_view SplitMode; + auto Err = Doc["split-mode"].get().get(SplitMode); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the split-mode option."sv) + } + if (SplitMode == "none"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_NONE; + } else if (SplitMode == "layer"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_LAYER; + } else if (SplitMode == "row"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_ROW; + } else { + RET_ERROR(ErrNo::InvalidArgument, + "Unknown split-mode: {}. Valid: none, layer, row."sv, SplitMode) + } + } + + // The context parameters. + if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { + int64_t CtxSize; + auto Err = Doc["ctx-size"].get().get(CtxSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ctx-size option."sv) + } + GraphRef.Params.n_ctx = static_cast(CtxSize); + } + if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { + int64_t BatchSize; + auto Err = Doc["batch-size"].get().get(BatchSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the batch-size option."sv) + } + GraphRef.Params.n_batch = static_cast(BatchSize); + } + if (Doc.at_key("ubatch-size").error() == simdjson::SUCCESS) { + int64_t UBatchSize; + auto Err = Doc["ubatch-size"].get().get(UBatchSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ubatch-size option."sv) + } + GraphRef.Params.n_ubatch = static_cast(UBatchSize); + } + if (Doc.at_key("n-keep").error() == simdjson::SUCCESS) { + int64_t NKeep; + auto Err = Doc["n-keep"].get().get(NKeep); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-keep option."sv) + } + GraphRef.Params.n_keep = static_cast(NKeep); + } + if (Doc.at_key("n-chunks").error() == simdjson::SUCCESS) { + int64_t NChunks; + auto Err = Doc["n-chunks"].get().get(NChunks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-chunks option."sv) + } + GraphRef.Params.n_chunks = static_cast(NChunks); + } + if (Doc.at_key("n-parallel").error() == simdjson::SUCCESS) { + int64_t NParallel; + auto Err = Doc["n-parallel"].get().get(NParallel); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-parallel option."sv) + } + GraphRef.Params.n_parallel = static_cast(NParallel); + } + if (Doc.at_key("n-sequences").error() == simdjson::SUCCESS) { + int64_t NSequences; + auto Err = Doc["n-sequences"].get().get(NSequences); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-sequences option."sv) + } + GraphRef.Params.n_sequences = static_cast(NSequences); + } + if (Doc.at_key("grp-attn-n").error() == simdjson::SUCCESS) { + int64_t GrpAttnN; + auto Err = Doc["grp-attn-n"].get().get(GrpAttnN); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the grp-attn-n option."sv) + } + GraphRef.Params.grp_attn_n = static_cast(GrpAttnN); + } + if (Doc.at_key("grp-attn-w").error() == simdjson::SUCCESS) { + int64_t GrpAttnW; + auto Err = Doc["grp-attn-w"].get().get(GrpAttnW); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the grp-attn-w option."sv) + } + GraphRef.Params.grp_attn_w = static_cast(GrpAttnW); + } + if (Doc.at_key("n-print").error() == simdjson::SUCCESS) { + int64_t NPrint; + auto Err = Doc["n-print"].get().get(NPrint); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-print option."sv) + } + GraphRef.Params.n_print = static_cast(NPrint); + } + if (Doc.at_key("rope-freq-base").error() == simdjson::SUCCESS) { + double RopeFreqBase; + auto Err = Doc["rope-freq-base"].get().get(RopeFreqBase); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the rope-freq-base option."sv) + } + GraphRef.Params.rope_freq_base = static_cast(RopeFreqBase); + } + if (Doc.at_key("rope-freq-scale").error() == simdjson::SUCCESS) { + double RopeFreqScale; + auto Err = Doc["rope-freq-scale"].get().get(RopeFreqScale); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the rope-freq-scale option."sv) + } + GraphRef.Params.rope_freq_scale = static_cast(RopeFreqScale); + } + if (Doc.at_key("yarn-ext-factor").error() == simdjson::SUCCESS) { + double YarnExtFactor; + auto Err = Doc["yarn-ext-factor"].get().get(YarnExtFactor); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the yarn-ext-factor option."sv) + } + GraphRef.Params.yarn_ext_factor = static_cast(YarnExtFactor); + } + if (Doc.at_key("yarn-attn-factor").error() == simdjson::SUCCESS) { + double YarnAttnFactor; + auto Err = Doc["yarn-attn-factor"].get().get(YarnAttnFactor); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the yarn-attn-factor option."sv) + } + GraphRef.Params.yarn_attn_factor = static_cast(YarnAttnFactor); + } + if (Doc.at_key("yarn-beta-fast").error() == simdjson::SUCCESS) { + double YarnBetaFast; + auto Err = Doc["yarn-beta-fast"].get().get(YarnBetaFast); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the yarn-beta-fast option."sv) + } + GraphRef.Params.yarn_beta_fast = static_cast(YarnBetaFast); + } + if (Doc.at_key("yarn-beta-slow").error() == simdjson::SUCCESS) { + double YarnBetaSlow; + auto Err = Doc["yarn-beta-slow"].get().get(YarnBetaSlow); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the yarn-beta-slow option."sv) + } + GraphRef.Params.yarn_beta_slow = static_cast(YarnBetaSlow); + } + if (Doc.at_key("yarn-orig-ctx").error() == simdjson::SUCCESS) { + int64_t YarnOrigCtx; + auto Err = Doc["yarn-orig-ctx"].get().get(YarnOrigCtx); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the yarn-orig-ctx option."sv) + } + GraphRef.Params.yarn_orig_ctx = static_cast(YarnOrigCtx); + } + if (Doc.at_key("defrag-thold").error() == simdjson::SUCCESS) { + double DefragThold; + auto Err = Doc["defrag-thold"].get().get(DefragThold); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the defrag-thold option."sv) + } + GraphRef.Params.defrag_thold = static_cast(DefragThold); + } + if (Doc.at_key("mask-valid").error() == simdjson::SUCCESS) { + auto Err = + Doc["mask-valid"].get().get(GraphRef.Params.cpuparams.mask_valid); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mask-valid option."sv) + } + } + if (Doc.at_key("priority").error() == simdjson::SUCCESS) { + int64_t Priority; + auto Err = Doc["priority"].get().get(Priority); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the priority option."sv) + } + GraphRef.Params.cpuparams.priority = + static_cast(Priority); + } + if (Doc.at_key("strict-cpu").error() == simdjson::SUCCESS) { + auto Err = + Doc["strict-cpu"].get().get(GraphRef.Params.cpuparams.strict_cpu); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the strict-cpu option."sv) + } + } + if (Doc.at_key("poll").error() == simdjson::SUCCESS) { + int64_t Poll; + auto Err = Doc["poll"].get().get(Poll); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the poll option."sv) + } + GraphRef.Params.cpuparams.poll = static_cast(Poll); + } + if (Doc.at_key("mask-valid-batch").error() == simdjson::SUCCESS) { + auto Err = Doc["mask-valid-batch"].get().get( + GraphRef.Params.cpuparams_batch.mask_valid); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mask-valid-batch option."sv) + } + } + if (Doc.at_key("priority-batch").error() == simdjson::SUCCESS) { + int64_t Priority; + auto Err = Doc["priority-batch"].get().get(Priority); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the priority-batch option."sv) + } + GraphRef.Params.cpuparams_batch.priority = + static_cast(Priority); + } + if (Doc.at_key("strict-cpu-batch").error() == simdjson::SUCCESS) { + auto Err = Doc["strict-cpu-batch"].get().get( + GraphRef.Params.cpuparams_batch.strict_cpu); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the strict-cpu-batch option."sv) + } + } + if (Doc.at_key("poll-batch").error() == simdjson::SUCCESS) { + int64_t Poll; + auto Err = Doc["poll-batch"].get().get(Poll); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the poll-batch option."sv) + } + GraphRef.Params.cpuparams_batch.poll = static_cast(Poll); + } + if (Doc.at_key("numa").error() == simdjson::SUCCESS) { + int64_t Numa; + auto Err = Doc["numa"].get().get(Numa); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the numa option."sv) + } + GraphRef.Params.numa = static_cast(Numa); + } + if (Doc.at_key("rope-scaling-type").error() == simdjson::SUCCESS) { + int64_t RopeScalingType; + auto Err = Doc["rope-scaling-type"].get().get(RopeScalingType); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the rope-scaling-type option."sv) + } + GraphRef.Params.rope_scaling_type = + static_cast(RopeScalingType); + } + if (Doc.at_key("pooling-type").error() == simdjson::SUCCESS) { + int64_t PoolingType; + auto Err = Doc["pooling-type"].get().get(PoolingType); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the pooling-type option."sv) + } + GraphRef.Params.pooling_type = + static_cast(PoolingType); + } + if (Doc.at_key("attention-type").error() == simdjson::SUCCESS) { + int64_t AttentionType; + auto Err = Doc["attention-type"].get().get(AttentionType); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the attention-type option."sv) + } + GraphRef.Params.attention_type = + static_cast(AttentionType); + } + if (Doc.at_key("threads").error() == simdjson::SUCCESS) { + int64_t NThreads; + auto Err = Doc["threads"].get().get(NThreads); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the threads option."sv) + } + GraphRef.Params.cpuparams.n_threads = static_cast(NThreads); + } + if (Doc.at_key("threads-batch").error() == simdjson::SUCCESS) { + int64_t NThreadsBatch; + auto Err = Doc["threads-batch"].get().get(NThreadsBatch); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the threads-batch option."sv) + } + GraphRef.Params.cpuparams_batch.n_threads = + static_cast(NThreadsBatch); + } + + // The sampling parameters. + if (Doc.at_key("n-prev").error() == simdjson::SUCCESS) { + int64_t NPrev; + auto Err = Doc["n-prev"].get().get(NPrev); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n_prev option."sv) + } + GraphRef.Params.sparams.n_prev = static_cast(NPrev); + } + if (Doc.at_key("top-k").error() == simdjson::SUCCESS) { + int64_t TopK; + auto Err = Doc["top-k"].get().get(TopK); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the top-k option."sv) + } + GraphRef.Params.sparams.top_k = static_cast(TopK); + } + if (Doc.at_key("min-p").error() == simdjson::SUCCESS) { + double MinP; + auto Err = Doc["min-p"].get().get(MinP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the min-p option."sv) + } + GraphRef.Params.sparams.min_p = static_cast(MinP); + } + if (Doc.at_key("typ-p").error() == simdjson::SUCCESS) { + double TypP; + auto Err = Doc["typ-p"].get().get(TypP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the typ-p option."sv) + } + GraphRef.Params.sparams.typ_p = static_cast(TypP); + } + if (Doc.at_key("penalize-nl").error() == simdjson::SUCCESS) { + auto Err = + Doc["penalize-nl"].get().get(GraphRef.Params.sparams.penalize_nl); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the penalize-nl option."sv) + } + } + if (Doc.at_key("tfs").error() == simdjson::SUCCESS) { + double TfsZ; + auto Err = Doc["tfs"].get().get(TfsZ); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the tfs option."sv) + } + GraphRef.Params.sparams.tfs_z = static_cast(TfsZ); + } + if (Doc.at_key("dynatemp-range").error() == simdjson::SUCCESS) { + double DynaTempRange; + auto Err = Doc["dynatemp-range"].get().get(DynaTempRange); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the dynatemp-range option."sv) + } + GraphRef.Params.sparams.dynatemp_range = static_cast(DynaTempRange); + } + if (Doc.at_key("dynatemp-exponent").error() == simdjson::SUCCESS) { + double DynaTempExponent; + auto Err = Doc["dynatemp-exponent"].get().get(DynaTempExponent); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the dynatemp-exponent option."sv) + } + GraphRef.Params.sparams.dynatemp_exponent = + static_cast(DynaTempExponent); + } + if (Doc.at_key("last-n-penalty").error() == simdjson::SUCCESS) { + int64_t LastNPenalty; + auto Err = Doc["last-n-penalty"].get().get(LastNPenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the last-n-penalty option."sv) + } + GraphRef.Params.sparams.penalty_last_n = static_cast(LastNPenalty); + } + if (Doc.at_key("temp").error() == simdjson::SUCCESS) { + double Temp; + auto Err = Doc["temp"].get().get(Temp); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the temp option."sv) + } + GraphRef.Params.sparams.temp = static_cast(std::max(0.0, Temp)); + } + if (Doc.at_key("top-p").error() == simdjson::SUCCESS) { + double TopP; + auto Err = Doc["top-p"].get().get(TopP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the top-p option."sv) + } + GraphRef.Params.sparams.top_p = static_cast(std::max(0.0, TopP)); + } + if (Doc.at_key("repeat-penalty").error() == simdjson::SUCCESS) { + double RepeatPenalty; + auto Err = Doc["repeat-penalty"].get().get(RepeatPenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the repeat-penalty option."sv) + } + GraphRef.Params.sparams.penalty_repeat = + static_cast(std::max(0.0, RepeatPenalty)); + } + if (Doc.at_key("presence-penalty").error() == simdjson::SUCCESS) { + double PresencePenalty; + auto Err = Doc["presence-penalty"].get().get(PresencePenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the presence-penalty option."sv) + } + GraphRef.Params.sparams.penalty_present = + static_cast(std::max(0.0, PresencePenalty)); + } + if (Doc.at_key("frequency-penalty").error() == simdjson::SUCCESS) { + double FrequencyPenalty; + auto Err = Doc["frequency-penalty"].get().get(FrequencyPenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the frequency-penalty option."sv) + } + GraphRef.Params.sparams.penalty_freq = + static_cast(std::max(0.0, FrequencyPenalty)); + } + if (Doc.at_key("mirostat").error() == simdjson::SUCCESS) { + int64_t Mirostat; + auto Err = Doc["mirostat"].get().get(Mirostat); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mirostat option."sv) + } + GraphRef.Params.sparams.mirostat = static_cast(Mirostat); + } + if (Doc.at_key("mirostat-eta").error() == simdjson::SUCCESS) { + double MirostatEta; + auto Err = Doc["mirostat-eta"].get().get(MirostatEta); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mirostat-eta option."sv) + } + GraphRef.Params.sparams.mirostat_eta = static_cast(MirostatEta); + } + if (Doc.at_key("mirostat-ent").error() == simdjson::SUCCESS) { + double MirostatEnt; + auto Err = Doc["mirostat-ent"].get().get(MirostatEnt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mirostat-ent option."sv) + } + GraphRef.Params.sparams.mirostat_tau = static_cast(MirostatEnt); + } + if (Doc.at_key("ignore-eos").error() == simdjson::SUCCESS) { + auto Err = + Doc["ignore-eos"].get().get(GraphRef.Params.sparams.ignore_eos); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ignore-eos option."sv) + } + } + if (Doc.at_key("no-perf-sampling").error() == simdjson::SUCCESS) { + auto Err = Doc["no-perf-sampling"].get().get( + GraphRef.Params.sparams.no_perf); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the no-perf-sampling option."sv) + } + } + if (Doc.at_key("grammar").error() == simdjson::SUCCESS) { + std::string_view Grammar; + auto Err = Doc["grammar"].get().get(Grammar); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the grammar option."sv) + } + GraphRef.Params.sparams.grammar = Grammar; + } + if (Doc.at_key("seed").error() == simdjson::SUCCESS) { + uint64_t Seed; + auto Err = Doc["seed"].get().get(Seed); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the seed option."sv) + } + GraphRef.Params.sparams.seed = static_cast(Seed); + } + // The speculative parameters. + if (Doc.at_key("n-gpu-layers-draft").error() == simdjson::SUCCESS) { + int64_t NGPULayersDraft; + auto Err = Doc["n-gpu-layers-draft"].get().get(NGPULayersDraft); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-gpu-layers-draft option."sv) + } + GraphRef.Params.n_gpu_layers_draft = static_cast(NGPULayersDraft); + } + if (Doc.at_key("p-split").error() == simdjson::SUCCESS) { + double PSplit; + auto Err = Doc["p-split"].get().get(PSplit); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the p-split option."sv) + } + GraphRef.Params.p_split = static_cast(PSplit); + } + + // The config parameters. + if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { + auto Err = Doc["stream-stdout"].get().get(ConfRef.StreamStdout); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the stream-stdout option."sv) + } + } + if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { + auto Err = Doc["n-predict"].get().get(ConfRef.NPredict); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-predict option."sv) + } + } + if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { + std::string_view ReversePrompt; + auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the reverse-prompt option."sv) + } + ConfRef.ReversePrompt = ReversePrompt; + } + if (Doc.at_key("model-alias").error() == simdjson::SUCCESS) { + std::string_view ModelAlias; + auto Err = Doc["model-alias"].get().get(ModelAlias); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the model-alias option."sv) + } + GraphRef.Params.model_alias = ModelAlias; + } + if (Doc.at_key("model-url").error() == simdjson::SUCCESS) { + std::string_view ModelUrl; + auto Err = Doc["model-url"].get().get(ModelUrl); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the model-url option."sv) + } + GraphRef.Params.model_url = ModelUrl; + } + if (Doc.at_key("hf-token").error() == simdjson::SUCCESS) { + std::string_view HfToken; + auto Err = Doc["hf-token"].get().get(HfToken); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-token option."sv) + } + GraphRef.Params.hf_token = HfToken; + } + if (Doc.at_key("hf-repo").error() == simdjson::SUCCESS) { + std::string_view HfRepo; + auto Err = Doc["hf-repo"].get().get(HfRepo); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-repo option."sv) + } + GraphRef.Params.hf_repo = HfRepo; + } + if (Doc.at_key("hf-file").error() == simdjson::SUCCESS) { + std::string_view HfFile; + auto Err = Doc["hf-file"].get().get(HfFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-file option."sv) + } + GraphRef.Params.hf_file = HfFile; + } + if (Doc.at_key("prompt-file").error() == simdjson::SUCCESS) { + std::string_view PromptFile; + auto Err = Doc["prompt-file"].get().get(PromptFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the prompt-file option."sv) + } + GraphRef.Params.prompt_file = PromptFile; + } + if (Doc.at_key("path-prompt-cache").error() == simdjson::SUCCESS) { + std::string_view PathPromptCache; + auto Err = + Doc["path-prompt-cache"].get().get(PathPromptCache); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the path-prompt-cache option."sv) + } + GraphRef.Params.path_prompt_cache = PathPromptCache; + } + if (Doc.at_key("input-prefix").error() == simdjson::SUCCESS) { + std::string_view InputPrefix; + auto Err = Doc["input-prefix"].get().get(InputPrefix); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the input-prefix option."sv) + } + GraphRef.Params.input_prefix = InputPrefix; + } + if (Doc.at_key("input-suffix").error() == simdjson::SUCCESS) { + std::string_view InputSuffix; + auto Err = Doc["input-suffix"].get().get(InputSuffix); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the input-suffix option."sv) + } + GraphRef.Params.input_suffix = InputSuffix; + } + if (Doc.at_key("lookup-cache-static").error() == simdjson::SUCCESS) { + std::string_view LookupCacheStatic; + auto Err = Doc["lookup-cache-static"].get().get( + LookupCacheStatic); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the lookup-cache-static option."sv) + } + GraphRef.Params.lookup_cache_static = LookupCacheStatic; + } + if (Doc.at_key("lookup-cache-dynamic").error() == simdjson::SUCCESS) { + std::string_view LookupCacheDynamic; + auto Err = Doc["lookup-cache-dynamic"].get().get( + LookupCacheDynamic); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the lookup-cache-dynamic option."sv) + } + GraphRef.Params.lookup_cache_dynamic = LookupCacheDynamic; + } + if (Doc.at_key("logits-file").error() == simdjson::SUCCESS) { + std::string_view LogitsFile; + auto Err = Doc["logits-file"].get().get(LogitsFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the logits-file option."sv) + } + GraphRef.Params.logits_file = LogitsFile; + } + if (Doc.at_key("lora-init-without-apply").error() == simdjson::SUCCESS) { + auto Err = Doc["lora-init-without-apply"].get().get( + GraphRef.Params.lora_init_without_apply); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the lora-init-without-apply option."sv) + } + } + if (Doc.at_key("verbosity").error() == simdjson::SUCCESS) { + int64_t Verbosity; + auto Err = Doc["verbosity"].get().get(Verbosity); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the verbosity option."sv) + } + GraphRef.Params.verbosity = static_cast(Verbosity); + } + if (Doc.at_key("control-vector-layer-start").error() == simdjson::SUCCESS) { + int64_t ControlVectorLayerStart; + auto Err = Doc["control-vector-layer-start"].get().get( + ControlVectorLayerStart); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the control-vector-layer-start option."sv) + } + GraphRef.Params.control_vector_layer_start = + static_cast(ControlVectorLayerStart); + } + if (Doc.at_key("control-vector-layer-end").error() == simdjson::SUCCESS) { + int64_t ControlVectorLayerEnd; + auto Err = Doc["control-vector-layer-end"].get().get( + ControlVectorLayerEnd); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the control-vector-layer-end option."sv) + } + GraphRef.Params.control_vector_layer_end = + static_cast(ControlVectorLayerEnd); + } + if (Doc.at_key("ppl-stride").error() == simdjson::SUCCESS) { + int64_t PplStride; + auto Err = Doc["ppl-stride"].get().get(PplStride); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ppl-stride option."sv) + } + GraphRef.Params.ppl_stride = static_cast(PplStride); + } + if (Doc.at_key("ppl-output-type").error() == simdjson::SUCCESS) { + int64_t PplOutputType; + auto Err = Doc["ppl-output-type"].get().get(PplOutputType); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ppl-output-type option."sv) + } + GraphRef.Params.ppl_output_type = static_cast(PplOutputType); + } + if (Doc.at_key("hellaswag").error() == simdjson::SUCCESS) { + auto Err = Doc["hellaswag"].get().get(GraphRef.Params.hellaswag); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hellaswag option."sv) + } + } + if (Doc.at_key("hellaswag-tasks").error() == simdjson::SUCCESS) { + uint64_t HellaswagTasks; + auto Err = Doc["hellaswag-tasks"].get().get(HellaswagTasks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hellaswag-tasks option."sv) + } + GraphRef.Params.hellaswag_tasks = HellaswagTasks; + } + if (Doc.at_key("winogrande").error() == simdjson::SUCCESS) { + auto Err = Doc["winogrande"].get().get(GraphRef.Params.winogrande); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the winogrande option."sv) + } + } + if (Doc.at_key("winogrande-tasks").error() == simdjson::SUCCESS) { + uint64_t WinograndeTasks; + auto Err = Doc["winogrande-tasks"].get().get(WinograndeTasks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the winogrande-tasks option."sv) + } + GraphRef.Params.winogrande_tasks = WinograndeTasks; + } + if (Doc.at_key("multiple-choice").error() == simdjson::SUCCESS) { + auto Err = + Doc["multiple-choice"].get().get(GraphRef.Params.multiple_choice); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the multiple-choice option."sv) + } + } + if (Doc.at_key("multiple-choice-tasks").error() == simdjson::SUCCESS) { + uint64_t MultipleChoiceTasks; + auto Err = + Doc["multiple-choice-tasks"].get().get(MultipleChoiceTasks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the multiple-choice-tasks option."sv) + } + GraphRef.Params.multiple_choice_tasks = MultipleChoiceTasks; + } + if (Doc.at_key("kl-divergence").error() == simdjson::SUCCESS) { + auto Err = + Doc["kl-divergence"].get().get(GraphRef.Params.kl_divergence); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the kl-divergence option."sv) + } + } + if (Doc.at_key("usage").error() == simdjson::SUCCESS) { + auto Err = Doc["usage"].get().get(GraphRef.Params.usage); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the usage option."sv) + } + } + if (Doc.at_key("use-color").error() == simdjson::SUCCESS) { + auto Err = Doc["use-color"].get().get(GraphRef.Params.use_color); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the use-color option."sv) + } + } + if (Doc.at_key("special").error() == simdjson::SUCCESS) { + auto Err = Doc["special"].get().get(GraphRef.Params.special); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the special option."sv) + } + } + if (Doc.at_key("interactive").error() == simdjson::SUCCESS) { + auto Err = Doc["interactive"].get().get(GraphRef.Params.interactive); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the interactive option."sv) + } + } + if (Doc.at_key("interactive-first").error() == simdjson::SUCCESS) { + auto Err = Doc["interactive-first"].get().get( + GraphRef.Params.interactive_first); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the interactive-first option."sv) + } + } + if (Doc.at_key("prompt-cache-all").error() == simdjson::SUCCESS) { + auto Err = Doc["prompt-cache-all"].get().get( + GraphRef.Params.prompt_cache_all); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the prompt-cache-all option."sv) + } + } + if (Doc.at_key("prompt-cache-ro").error() == simdjson::SUCCESS) { + auto Err = + Doc["prompt-cache-ro"].get().get(GraphRef.Params.prompt_cache_ro); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the prompt-cache-ro option."sv) + } + } + if (Doc.at_key("escape").error() == simdjson::SUCCESS) { + auto Err = Doc["escape"].get().get(GraphRef.Params.escape); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the escape option."sv) + } + } + if (Doc.at_key("multiline-input").error() == simdjson::SUCCESS) { + auto Err = + Doc["multiline-input"].get().get(GraphRef.Params.multiline_input); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the multiline-input option."sv) + } + } + if (Doc.at_key("simple-io").error() == simdjson::SUCCESS) { + auto Err = Doc["simple-io"].get().get(GraphRef.Params.simple_io); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the simple-io option."sv) + } + } + if (Doc.at_key("cont-batching").error() == simdjson::SUCCESS) { + auto Err = + Doc["cont-batching"].get().get(GraphRef.Params.cont_batching); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cont-batching option."sv) + } + } + if (Doc.at_key("flash-attn").error() == simdjson::SUCCESS) { + auto Err = Doc["flash-attn"].get().get(GraphRef.Params.flash_attn); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the flash-attn option."sv) + } + } + if (Doc.at_key("no-perf").error() == simdjson::SUCCESS) { + auto Err = Doc["no-perf"].get().get(GraphRef.Params.no_perf); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the no-perf option."sv) + } + } + if (Doc.at_key("ctx-shift").error() == simdjson::SUCCESS) { + auto Err = Doc["ctx-shift"].get().get(GraphRef.Params.ctx_shift); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ctx-shift option."sv) + } + } + if (Doc.at_key("input-prefix-bos").error() == simdjson::SUCCESS) { + auto Err = Doc["input-prefix-bos"].get().get( + GraphRef.Params.input_prefix_bos); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the input-prefix-bos option."sv) + } + } + if (Doc.at_key("use-mlock").error() == simdjson::SUCCESS) { + auto Err = Doc["use-mlock"].get().get(GraphRef.Params.use_mlock); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the use-mlock option."sv) + } + } + if (Doc.at_key("use-mmap").error() == simdjson::SUCCESS) { + auto Err = Doc["use-mmap"].get().get(GraphRef.Params.use_mmap); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the use-mmap option."sv) + } + } + if (Doc.at_key("verbose-prompt").error() == simdjson::SUCCESS) { + auto Err = + Doc["verbose-prompt"].get().get(GraphRef.Params.verbose_prompt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the verbose-prompt option."sv) + } + } + if (Doc.at_key("display-prompt").error() == simdjson::SUCCESS) { + auto Err = + Doc["display-prompt"].get().get(GraphRef.Params.display_prompt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the display-prompt option."sv) + } + } + if (Doc.at_key("no-kv-offload").error() == simdjson::SUCCESS) { + auto Err = + Doc["no-kv-offload"].get().get(GraphRef.Params.no_kv_offload); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the no-kv-offload option."sv) + } + } + if (Doc.at_key("warmup").error() == simdjson::SUCCESS) { + auto Err = Doc["warmup"].get().get(GraphRef.Params.warmup); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the warmup option."sv) + } + } + if (Doc.at_key("check-tensors").error() == simdjson::SUCCESS) { + auto Err = + Doc["check-tensors"].get().get(GraphRef.Params.check_tensors); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the check-tensors option."sv) + } + } + if (Doc.at_key("cache-type-k").error() == simdjson::SUCCESS) { + std::string_view CacheTypeK; + auto Err = Doc["cache-type-k"].get().get(CacheTypeK); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cache-type-k option."sv) + } + GraphRef.Params.cache_type_k = CacheTypeK; + } + if (Doc.at_key("cache-type-v").error() == simdjson::SUCCESS) { + std::string_view CacheTypeV; + auto Err = Doc["cache-type-v"].get().get(CacheTypeV); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cache-type-v option."sv) + } + GraphRef.Params.cache_type_v = CacheTypeV; + } + if (Doc.at_key("embd-normalize").error() == simdjson::SUCCESS) { + int64_t EmbdNormalize; + auto Err = Doc["embd-normalize"].get().get(EmbdNormalize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embd-normalize option."sv) + } + GraphRef.Params.embd_normalize = static_cast(EmbdNormalize); + } + if (Doc.at_key("embd-out").error() == simdjson::SUCCESS) { + std::string_view EmbdOut; + auto Err = Doc["embd-out"].get().get(EmbdOut); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embd-out option."sv) + } + GraphRef.Params.embd_out = EmbdOut; + } + if (Doc.at_key("embd-sep").error() == simdjson::SUCCESS) { + std::string_view EmbdSep; + auto Err = Doc["embd-sep"].get().get(EmbdSep); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embd-sep option."sv) + } + GraphRef.Params.embd_sep = EmbdSep; + } + if (Doc.at_key("reranking").error() == simdjson::SUCCESS) { + auto Err = Doc["reranking"].get().get(GraphRef.Params.reranking); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the reranking option."sv) + } + } + if (Doc.at_key("port").error() == simdjson::SUCCESS) { + int64_t Port; + auto Err = Doc["port"].get().get(Port); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the port option."sv) + } + GraphRef.Params.port = static_cast(Port); + } + if (Doc.at_key("timeout-read").error() == simdjson::SUCCESS) { + int64_t TimeoutRead; + auto Err = Doc["timeout-read"].get().get(TimeoutRead); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the timeout-read option."sv) + } + GraphRef.Params.timeout_read = static_cast(TimeoutRead); + } + if (Doc.at_key("timeout-write").error() == simdjson::SUCCESS) { + int64_t TimeoutWrite; + auto Err = Doc["timeout-write"].get().get(TimeoutWrite); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the timeout-write option."sv) + } + GraphRef.Params.timeout_write = static_cast(TimeoutWrite); + } + if (Doc.at_key("n-threads-http").error() == simdjson::SUCCESS) { + int64_t NThreadsHttp; + auto Err = Doc["n-threads-http"].get().get(NThreadsHttp); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-threads-http option."sv) + } + GraphRef.Params.n_threads_http = static_cast(NThreadsHttp); + } + if (Doc.at_key("n-cache-reuse").error() == simdjson::SUCCESS) { + int64_t NCacheReuse; + auto Err = Doc["n-cache-reuse"].get().get(NCacheReuse); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-cache-reuse option."sv) + } + GraphRef.Params.n_cache_reuse = static_cast(NCacheReuse); + } + if (Doc.at_key("hostname").error() == simdjson::SUCCESS) { + std::string_view Hostname; + auto Err = Doc["hostname"].get().get(Hostname); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hostname option."sv) + } + GraphRef.Params.hostname = Hostname; + } + if (Doc.at_key("public-path").error() == simdjson::SUCCESS) { + std::string_view PublicPath; + auto Err = Doc["public-path"].get().get(PublicPath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the public-path option."sv) + } + GraphRef.Params.public_path = PublicPath; + } + if (Doc.at_key("chat-template").error() == simdjson::SUCCESS) { + std::string_view ChatTemplate; + auto Err = Doc["chat-template"].get().get(ChatTemplate); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the chat-template option."sv) + } + GraphRef.Params.chat_template = ChatTemplate; + } + if (Doc.at_key("enable-chat-template").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-chat-template"].get().get( + GraphRef.Params.enable_chat_template); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-chat-template option."sv) + } + } + if (Doc.at_key("ssl-file-key").error() == simdjson::SUCCESS) { + std::string_view SslFileKey; + auto Err = Doc["ssl-file-key"].get().get(SslFileKey); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ssl-file-key option."sv) + } + GraphRef.Params.ssl_file_key = SslFileKey; + } + if (Doc.at_key("ssl-file-cert").error() == simdjson::SUCCESS) { + std::string_view SslFileCert; + auto Err = Doc["ssl-file-cert"].get().get(SslFileCert); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ssl-file-cert option."sv) + } + GraphRef.Params.ssl_file_cert = SslFileCert; + } + if (Doc.at_key("webui").error() == simdjson::SUCCESS) { + auto Err = Doc["webui"].get().get(GraphRef.Params.webui); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the webui option."sv) + } + } + if (Doc.at_key("endpoint-slots").error() == simdjson::SUCCESS) { + int64_t EndpointSlots; + auto Err = Doc["endpoint-slots"].get().get(EndpointSlots); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the endpoint-slots option."sv) + } + GraphRef.Params.endpoint_slots = static_cast(EndpointSlots); + } + if (Doc.at_key("endpoint-props").error() == simdjson::SUCCESS) { + auto Err = + Doc["endpoint-props"].get().get(GraphRef.Params.endpoint_props); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the endpoint-props option."sv) + } + } + if (Doc.at_key("endpoint-metrics").error() == simdjson::SUCCESS) { + auto Err = Doc["endpoint-metrics"].get().get( + GraphRef.Params.endpoint_metrics); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the endpoint-metrics option."sv) + } + } + if (Doc.at_key("log-json").error() == simdjson::SUCCESS) { + auto Err = Doc["log-json"].get().get(GraphRef.Params.log_json); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the log-json option."sv) + } + } + if (Doc.at_key("slot-save-path").error() == simdjson::SUCCESS) { + std::string_view SlotSavePath; + auto Err = Doc["slot-save-path"].get().get(SlotSavePath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the slot-save-path option."sv) + } + GraphRef.Params.slot_save_path = SlotSavePath; + } + if (Doc.at_key("slot-prompt-similarity").error() == simdjson::SUCCESS) { + double SlotPromptSimilarity; + auto Err = + Doc["slot-prompt-similarity"].get().get(SlotPromptSimilarity); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the slot-prompt-similarity option."sv) + } + GraphRef.Params.slot_prompt_similarity = + static_cast(SlotPromptSimilarity); + } + if (Doc.at_key("is-pp-shared").error() == simdjson::SUCCESS) { + auto Err = + Doc["is-pp-shared"].get().get(GraphRef.Params.is_pp_shared); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the is-pp-shared option."sv) + } + } + if (Doc.at_key("n-pp").error() == simdjson::SUCCESS) { + std::string_view NPP; + auto Err = Doc["n-pp"].get().get(NPP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-pp option."sv) + } + stringToList(std::string(NPP), GraphRef.Params.n_pp); + } + if (Doc.at_key("n-tg").error() == simdjson::SUCCESS) { + std::string_view NTG; + auto Err = Doc["n-tg"].get().get(NTG); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-tg option."sv) + } + stringToList(std::string(NTG), GraphRef.Params.n_tg); + } + if (Doc.at_key("n-pl").error() == simdjson::SUCCESS) { + std::string_view NPL; + auto Err = Doc["n-pl"].get().get(NPL); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-pl option."sv) + } + stringToList(std::string(NPL), GraphRef.Params.n_pl); + } + if (Doc.at_key("context-files").error() == simdjson::SUCCESS) { + std::string_view ContextFiles; + auto Err = Doc["context-files"].get().get(ContextFiles); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the context-files option."sv) + } + } + if (Doc.at_key("chunk-size").error() == simdjson::SUCCESS) { + int64_t ChunkSize; + auto Err = Doc["chunk-size"].get().get(ChunkSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the chunk-size option."sv) + } + GraphRef.Params.chunk_size = static_cast(ChunkSize); + } + if (Doc.at_key("chunk-separator").error() == simdjson::SUCCESS) { + std::string_view ChunkSeparator; + auto Err = + Doc["chunk-separator"].get().get(ChunkSeparator); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the chunk-separator option."sv) + } + GraphRef.Params.chunk_separator = ChunkSeparator; + } + if (Doc.at_key("n-junk").error() == simdjson::SUCCESS) { + int64_t NJunk; + auto Err = Doc["n-junk"].get().get(NJunk); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-junk option."sv) + } + GraphRef.Params.n_junk = static_cast(NJunk); + } + if (Doc.at_key("i-pos").error() == simdjson::SUCCESS) { + int64_t IPos; + auto Err = Doc["i-pos"].get().get(IPos); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the i-pos option."sv) + } + GraphRef.Params.i_pos = static_cast(IPos); + } + if (Doc.at_key("n-out-freq").error() == simdjson::SUCCESS) { + int64_t NOutFreq; + auto Err = Doc["n-out-freq"].get().get(NOutFreq); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-out-freq option."sv) + } + GraphRef.Params.n_out_freq = static_cast(NOutFreq); + } + if (Doc.at_key("n-save-freq").error() == simdjson::SUCCESS) { + int64_t NSaveFreq; + auto Err = Doc["n-save-freq"].get().get(NSaveFreq); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-save-freq option."sv) + } + GraphRef.Params.n_save_freq = static_cast(NSaveFreq); + } + if (Doc.at_key("i-chunk").error() == simdjson::SUCCESS) { + int64_t IChunk; + auto Err = Doc["i-chunk"].get().get(IChunk); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the i-chunk option."sv) + } + GraphRef.Params.i_chunk = static_cast(IChunk); + } + if (Doc.at_key("process-output").error() == simdjson::SUCCESS) { + auto Err = + Doc["process-output"].get().get(GraphRef.Params.process_output); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the process-output option."sv) + } + } + if (Doc.at_key("compute-ppl").error() == simdjson::SUCCESS) { + auto Err = Doc["compute-ppl"].get().get(GraphRef.Params.compute_ppl); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the compute-ppl option."sv) + } + } + if (Doc.at_key("n-pca-batch").error() == simdjson::SUCCESS) { + int64_t NPCABatch; + auto Err = Doc["n-pca-batch"].get().get(NPCABatch); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-pca-batch option."sv) + } + GraphRef.Params.n_pca_batch = static_cast(NPCABatch); + } + if (Doc.at_key("n-pca-iterations").error() == simdjson::SUCCESS) { + int64_t NPCAIterations; + auto Err = Doc["n-pca-iterations"].get().get(NPCAIterations); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-pca-iterations option."sv) + } + GraphRef.Params.n_pca_iterations = static_cast(NPCAIterations); + } + if (Doc.at_key("cvector-dimre-method").error() == simdjson::SUCCESS) { + std::string_view CVectorDimreMethod; + auto Err = Doc["cvector-dimre-method"].get().get( + CVectorDimreMethod); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cvector-dimre-method option."sv) + } + if (CVectorDimreMethod == "pca") { + GraphRef.Params.cvector_dimre_method = DIMRE_METHOD_PCA; + } else if (CVectorDimreMethod == "mean") { + GraphRef.Params.cvector_dimre_method = DIMRE_METHOD_MEAN; + } else { + RET_ERROR( + ErrNo::InvalidArgument, + "Invalid value for cvector-dimre-method: must be 'pca' or 'mean'."sv) + } + } + if (Doc.at_key("cvector-outfile").error() == simdjson::SUCCESS) { + std::string_view CVectorOutfile; + auto Err = + Doc["cvector-outfile"].get().get(CVectorOutfile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cvector-outfile option."sv) + } + GraphRef.Params.cvector_outfile = CVectorOutfile; + } + if (Doc.at_key("cvector-positive-file").error() == simdjson::SUCCESS) { + std::string_view CVectorPositiveFile; + auto Err = Doc["cvector-positive-file"].get().get( + CVectorPositiveFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cvector-positive-file option."sv) + } + GraphRef.Params.cvector_positive_file = CVectorPositiveFile; + } + if (Doc.at_key("cvector-negative-file").error() == simdjson::SUCCESS) { + std::string_view CVectorNegativeFile; + auto Err = Doc["cvector-negative-file"].get().get( + CVectorNegativeFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cvector-negative-file option."sv) + } + GraphRef.Params.cvector_negative_file = CVectorNegativeFile; + } + if (Doc.at_key("spm-infill").error() == simdjson::SUCCESS) { + auto Err = Doc["spm-infill"].get().get(GraphRef.Params.spm_infill); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the spm-infill option."sv) + } + } + if (Doc.at_key("out-file").error() == simdjson::SUCCESS) { + std::string_view Outfile; + auto Err = Doc["out-file"].get().get(Outfile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the out-file option."sv) + } + GraphRef.Params.out_file = Outfile; + } + if (Doc.at_key("batched-bench-output-jsonl").error() == simdjson::SUCCESS) { + auto Err = Doc["batched-bench-output-jsonl"].get().get( + GraphRef.Params.batched_bench_output_jsonl); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the batched-bench-output-jsonl option."sv) + } + } + + // Check if the model parameters are updated. + if (IsModelUpdated && (PrevNGPULayers != GraphRef.Params.n_gpu_layers || + PrevMainGpu != GraphRef.Params.main_gpu)) { + *IsModelUpdated = true; + } + + // Check if the context parameters are updated. + if (IsContextUpdated && (PrevCtxSize != GraphRef.Params.n_ctx || + PrevThreads != GraphRef.Params.cpuparams.n_threads || + PrevFlashAttn != GraphRef.Params.flash_attn || + PrevEmbedding != GraphRef.Params.embedding)) { + *IsContextUpdated = true; + } + + // Check if the sampler parameters are updated. + if (IsSamplerUpdated && + (PrevTemp != GraphRef.Params.sparams.temp || + PrevTopP != GraphRef.Params.sparams.top_p || + PrevRepeatPenalty != GraphRef.Params.sparams.penalty_repeat || + PrevPresencePenalty != GraphRef.Params.sparams.penalty_present || + PrevFrequencyPenalty != GraphRef.Params.sparams.penalty_freq || + PrevGrammar != GraphRef.Params.sparams.grammar || + PrevSeed != GraphRef.Params.sparams.seed)) { + *IsSamplerUpdated = true; + } + + return ErrNo::Success; +} + +// <<<<<<<< Metadata related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +// >>>>>>>> Output related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// Generate output metadata. +std::string buildOutputMetadata(Context &CxtRef) noexcept { + return fmt::format(R"({{"input_tokens": {}, )" + R"("output_tokens": {}, )" + R"("llama_build_number": {}, )" + R"("llama_commit": "{}"}})"sv, + CxtRef.LlamaNInputs, CxtRef.LlamaOutputTokens.size(), + LLAMA_BUILD_NUMBER, LLAMA_COMMIT); +} + +// Generate output embedding. +void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, + const float *Embeddings) noexcept { + // Embedding vector format + // | Content | + // | ----------------------------------- | + // | '{"number_embedding": ' | + // | n_embedding | + // | ', "embedding": ' | + // | '[' | + // | n_embedding*(embedding value %.10f) | + // | (n_embedding-1)*(',') | + // | ']' | + // | '}' | + Embedding = + fmt::format(R"({{"n_embedding": {}, )" + R"("embedding": [{:.10}]}})"sv, + NEmbd, fmt::join(Embeddings, Embeddings + NEmbd, ","sv)); +} +// <<<<<<<< Output related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +// >>>>>>>> Compute related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// Helper to initialize a llama batch. +struct llama_batch allocBatch(int64_t NTokens, int64_t Embd = 0, + int32_t NSeqMax = 1) noexcept { + struct llama_batch Batch = llama_batch_init( + /* n_tokens_alloc */ static_cast(NTokens), + /* embd */ static_cast(Embd), + /* n_seq_max */ static_cast(NSeqMax)); + std::fill(Batch.n_seq_id, Batch.n_seq_id + NTokens, + static_cast(NSeqMax)); + for (int64_t I = 0; I < NTokens; I++) { + std::fill(Batch.seq_id[I], Batch.seq_id[I] + NSeqMax, 0); + } + std::fill(Batch.logits, Batch.logits + NTokens, false); + return Batch; +} + +// Fill a batch with tokens (smaller than batch size) and position data. +void fillBatch(Span Tokens, Graph &GraphRef, + llama_batch &Batch, int &NPos, bool IsLogit = false) { + assuming(GraphRef.Params.n_batch >= static_cast(Tokens.size())); + assuming(Batch.token != nullptr); + assuming(Batch.pos != nullptr); + assuming(Batch.logits != nullptr); + // Fill the batch with pos information. + Batch.n_tokens = static_cast(Tokens.size()); + for (uint32_t I = 0; I < Tokens.size(); I++) { + Batch.token[I] = Tokens[I]; + Batch.pos[I] = NPos + I; + Batch.logits[I] = false; + } + + // Logits for sampling or the end of inputs. + if (IsLogit) { + Batch.logits[Tokens.size() - 1] = true; + } + + // Move the position. + NPos += static_cast(Tokens.size()); +} + +// Evaluate tokens. Construct the batch from tokens and decode. +ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, + llama_batch &Batch, int &NPos, + bool IsLogits = false) noexcept { + // End the inference if the context is full. + uint32_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); + if (NPos + static_cast(Tokens.size()) > NCtx) { + LOG_INFO( + GraphRef.EnableLog, + "evaluateTokens: the context if full ({} / {} tokens). Please increase your "sv + "context size."sv, + NPos + static_cast(Tokens.size()), NCtx) + return ErrNo::ContextFull; + } + + // Loop for decoding batches. Split tokens by batch size. + for (int I = 0; I < static_cast(Tokens.size()); + I += static_cast(GraphRef.Params.n_batch)) { + int NEval = static_cast(Tokens.size()) - I; + if (NEval > static_cast(GraphRef.Params.n_batch)) { + NEval = static_cast(GraphRef.Params.n_batch); + } + + // Fill the batch with pos information. + fillBatch(Span(Tokens.begin() + I, NEval), GraphRef, + Batch, NPos, + IsLogits && I + NEval >= static_cast(Tokens.size())); + + // Decode the batch. + auto Status = llama_decode(GraphRef.LlamaContext.get(), Batch); + if (Status == 1) { + RET_ERROR( + ErrNo::RuntimeError, + "evaluateTokens: failed to llama_decode: try reducing the size of the batch "sv + "or increasing the size of context."sv) + } + if (Status < 0) { + RET_ERROR( + ErrNo::RuntimeError, + "evaluateTokens: failed to llama_decode: internal fatal error. Please open "sv + "an issue on GitHub."sv) + } + } + + return ErrNo::Success; +} + +// Clear the context and reset the sampler. +void clearContext(Graph &GraphRef, Context &CxtRef) noexcept { + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext"sv) + llama_kv_cache_clear(GraphRef.LlamaContext.get()); + common_sampler_reset(CxtRef.LlamaSampler.get()); + CxtRef.NPos = 0; + CxtRef.LlamaOutputTokens.clear(); + CxtRef.LlamaOutputs.clear(); + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext...Done"sv) +} + +// Evaluate the input tokens. Clear all inputs on success. +ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, + std::string_view LogPrefix) noexcept { + // Check if the input is set before setting up the context. + if (CxtRef.LlamaInputs.size() == 0) { + RET_ERROR(ErrNo::InvalidArgument, "{}: llama input is not set!"sv, + LogPrefix) + } + + // Get the context size. + const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); + // Minus 4 for the special tokens. (Such as , , ... tokens.) + const uint64_t MaxTokensListSize = NCtx - 4; + // Return value. + auto ReturnCode = ErrNo::Success; + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + RET_ERROR(ErrNo::PromptTooLong, + "{}: the prompt is too long. Your input has {} tokens. "sv + "Please reduce it to {} tokens."sv, + LogPrefix, CxtRef.LlamaInputs.size(), MaxTokensListSize) + } + + // Evaluate input tokens. + ReturnCode = + evaluateTokens(Span(CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.size()), + GraphRef, CxtRef.LlamaBatch, CxtRef.NPos, true); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, "{}: failed to evaluate input tokens."sv, LogPrefix) + } + + return ErrNo::Success; +} + +// Sample and get the output token. +ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, + bool IsSingleTokenMode = false) noexcept { + // Use idx = -1 to sample the next token. + const llama_token Id = common_sampler_sample( + CxtRef.LlamaSampler.get(), GraphRef.LlamaContext.get(), /* idx */ -1); + common_sampler_accept(CxtRef.LlamaSampler.get(), Id, + /* accept_grammar */ true); + + // Save the output token. + CxtRef.LlamaOutputTokens.emplace_back(Id); + std::string OutputString = + common_token_to_piece(GraphRef.LlamaContext.get(), Id); + CxtRef.LlamaOutputs.insert(CxtRef.LlamaOutputs.end(), OutputString.begin(), + OutputString.end()); + // In single token mode, we do not handle StreamStdout and ReversePrompt. + if (!IsSingleTokenMode) { + // When setting StreamStdout, we print the output to stdout. + if (CxtRef.Conf.StreamStdout) { + fmt::print("{}"sv, + common_token_to_piece(GraphRef.LlamaContext.get(), Id)); + std::fflush(stdout); + } + // Break if reverse prompt is found. + if (!CxtRef.Conf.ReversePrompt.empty() && + std::string(CxtRef.LlamaOutputs.begin(), CxtRef.LlamaOutputs.end()) + .find(CxtRef.Conf.ReversePrompt) != std::string::npos) { + LOG_INFO(GraphRef.EnableLog, "sampleOutput: reverse prompt found."sv) + return ErrNo::EndOfSequence; + } + } + + // Deal with end of text token. + // Only stop on EOS if GraphRef.Params.sparams.ignore_eos is false. + if (!GraphRef.Params.sparams.ignore_eos) { + if (llama_token_is_eog(GraphRef.LlamaModel.get(), Id)) { + LOG_INFO(GraphRef.EnableLog, "sampleOutput: EOS token found."sv) + return ErrNo::EndOfSequence; + } + } + // Evaluate the output token. + return evaluateTokens(Span(&Id, 1), GraphRef, + CxtRef.OutputBatch, CxtRef.NPos, true); +} + +// TODO: Merge into compute. +Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding"sv) + + const llama_token SepTokenId = llama_token_sep(GraphRef.LlamaModel.get()); + if (SepTokenId > -1) { + if (CxtRef.LlamaInputs.size() > 0 && + CxtRef.LlamaInputs.back() != SepTokenId) { + LOG_WARN( + "getEmbedding: last token in the prompt is not SEP, "sv + "'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF "sv + "header."sv) + } + } + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > + GraphRef.Params.n_batch) { + RET_ERROR( + ErrNo::PromptTooLong, + "getEmbedding: the prompt is too long. Your input has {} tokens exceeds batch "sv + "size {}. Please reduce the input size or increase your batch-size."sv, + CxtRef.LlamaInputs.size(), GraphRef.Params.n_batch) + } + + // Evaluate the input tokens. + auto ReturnCode = evaluateInput(GraphRef, CxtRef, "getEmbedding"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + + // Main prediction loop. + const struct llama_model *LlamaModel = + llama_get_model(GraphRef.LlamaContext.get()); + const int32_t NEmbd = llama_n_embd(LlamaModel); + std::vector Embeddings(NEmbd); + + for (int I = 0; I < CxtRef.LlamaBatch.n_tokens; I++) { + if (!CxtRef.LlamaBatch.logits[I]) { + continue; + } + + // Try to get sequence embeddings. + auto *Embd = llama_get_embeddings_seq(GraphRef.LlamaContext.get(), + CxtRef.LlamaBatch.seq_id[I][0]); + if (Embd == nullptr) { + Embd = llama_get_embeddings_ith(GraphRef.LlamaContext.get(), I); + if (Embd == nullptr) { + LOG_ERROR("getEmbedding: failed to get embeddings for token {}"sv, I); + continue; + } + } + + // Normalize the embeddings. + common_embd_normalize(Embd, Embeddings.data(), NEmbd, + static_cast(CxtRef.Conf.EmbdNormalize)); + } + + std::string EmbeddingString; + buildOutputEmbedding(EmbeddingString, NEmbd, Embeddings.data()); + CxtRef.LlamaOutputs = + std::vector(EmbeddingString.begin(), EmbeddingString.end()); + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), /* Sampler */ nullptr); + } + + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding...Done"sv) + return ErrNo::Success; +} + +// <<<<<<<< Compute related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +} // namespace + +Expect load(WasiNNEnvironment &Env, Span> Builders, + [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { + if (Builders.empty()) { + RET_ERROR(ErrNo::InvalidArgument, + "Invalid builders size, builders size must be > 0."); + } + + // Add a graph + const uint32_t GId = Env.newGraph(Backend::BitNet); + auto &GraphRef = Env.NNGraph[GId].get(); + + // Initialize the plugin parameters. + GraphRef.EnableLog = false; + GraphRef.EnableDebugLog = false; + const common_params CommonParamsDefault; + GraphRef.Params = CommonParamsDefault; + GraphRef.Params.n_keep = 0; + GraphRef.Params.n_chunks = -1; + GraphRef.Params.n_parallel = 1; + GraphRef.Params.grp_attn_n = 1; + GraphRef.Params.grp_attn_w = 512; + GraphRef.Params.n_print = -1; + GraphRef.Params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER; + // Initialize the model parameters. + llama_model_params ModelParamsDefault = llama_model_default_params(); + GraphRef.Params.n_gpu_layers = ModelParamsDefault.n_gpu_layers; + GraphRef.Params.mmproj = ""sv; + GraphRef.Params.warmup = false; + // Initialize the context parameters. + llama_context_params ContextParamsDefault = llama_context_default_params(); + GraphRef.Params.n_ctx = ContextParamsDefault.n_ctx; + GraphRef.Params.n_batch = ContextParamsDefault.n_batch; + GraphRef.Params.n_ubatch = ContextParamsDefault.n_ubatch; + GraphRef.Params.cpuparams.n_threads = ContextParamsDefault.n_threads_batch; + GraphRef.Params.cpuparams_batch.n_threads = + ContextParamsDefault.n_threads_batch; + GraphRef.Params.rope_scaling_type = ContextParamsDefault.rope_scaling_type; + GraphRef.Params.pooling_type = ContextParamsDefault.pooling_type; + GraphRef.Params.attention_type = ContextParamsDefault.attention_type; + GraphRef.Params.rope_freq_base = ContextParamsDefault.rope_freq_base; + GraphRef.Params.rope_freq_scale = ContextParamsDefault.rope_freq_scale; + GraphRef.Params.yarn_ext_factor = ContextParamsDefault.yarn_ext_factor; + GraphRef.Params.yarn_attn_factor = ContextParamsDefault.yarn_attn_factor; + GraphRef.Params.yarn_beta_fast = ContextParamsDefault.yarn_beta_fast; + GraphRef.Params.yarn_beta_slow = ContextParamsDefault.yarn_beta_slow; + GraphRef.Params.yarn_orig_ctx = ContextParamsDefault.yarn_orig_ctx; + GraphRef.Params.defrag_thold = ContextParamsDefault.defrag_thold; + GraphRef.Params.cb_eval = ContextParamsDefault.cb_eval; + GraphRef.Params.cb_eval_user_data = ContextParamsDefault.cb_eval_user_data; + GraphRef.Params.embedding = ContextParamsDefault.embeddings; + GraphRef.Params.no_kv_offload = !ContextParamsDefault.offload_kqv; + GraphRef.Params.flash_attn = ContextParamsDefault.flash_attn; + GraphRef.Params.no_perf = ContextParamsDefault.no_perf; + + // Initialize the sampling parameters. + const common_sampler_params SamplerParamsDefault; + GraphRef.Params.sparams = SamplerParamsDefault; + + // Initialize the config parameters. + GraphRef.Conf.StreamStdout = false; + GraphRef.Conf.EmbdNormalize = + static_cast(CommonParamsDefault.embd_normalize); + GraphRef.Conf.NPredict = ContextParamsDefault.n_ctx; + GraphRef.Conf.ReversePrompt = ""sv; + + // Set llama log callback. + llama_log_set(llamaLogCallback, &GraphRef); + LOG_DEBUG(GraphRef.EnableDebugLog, "load start."sv) + + // If the graph builder length is greater than 1, builder[1] contains the + // metadata. + if (Builders.size() > 1) { + const std::string Metadata( + reinterpret_cast(Builders[1].data()), Builders[1].size()); + // Ignore context or model updates when initializing the graph. + auto Res = parseMetadata(GraphRef, GraphRef.Conf, Metadata); + if (Res != ErrNo::Success) { + Env.deleteGraph(GId); + RET_ERROR(Res, "load: Failed to parse metadata."sv); + } + } + + LOG_INFO(GraphRef.EnableLog, "LLAMA_COMMIT {}"sv, LLAMA_COMMIT) + LOG_INFO(GraphRef.EnableLog, "LLAMA_BUILD_NUMBER {}"sv, LLAMA_BUILD_NUMBER) + + LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path."sv) + const auto &Weight = Builders[0]; + const std::string_view BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + + if (BinModel.substr(0, 8) == "preload:"sv) { + GraphRef.Params.model = std::string(BinModel.substr(8)); + } else { + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: Model path not found in nn-preload, write model into "sv + "a tmpfile."sv) + GraphRef.Params.model = "bitnet-model.bin"sv; + std::ofstream TempFile(GraphRef.Params.model, + std::ios::out | std::ios::binary | std::ios::trunc); + if (!TempFile) { + Env.deleteGraph(GId); + RET_ERROR(ErrNo::InvalidArgument, "Failed to create temp model file."sv) + } + TempFile.write(BinModel.data(), BinModel.size()); + TempFile.close(); + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: Write model into a tmpfile...Done"sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path...Done"sv) + + // Check if the model exists. + if (!std::filesystem::exists( + std::filesystem::u8path(GraphRef.Params.model))) { + Env.deleteGraph(GId); + RET_ERROR(ErrNo::ModelNotFound, + "load: Model file not found at path: '{}'."sv, + GraphRef.Params.model) + } + + LOG_INFO(GraphRef.EnableLog, "load: Loading model from '{}'."sv, + GraphRef.Params.model) + + // Initialize model parameters. + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: initialize model with given parameters."sv) + + llama_backend_init(); + llama_numa_init(GraphRef.Params.numa); + + // Initialize the llama model and context. + common_init_result LlamaInit = common_init_from_params(GraphRef.Params); + GraphRef.LlamaModel.reset(LlamaInit.model); + GraphRef.LlamaContext.reset(LlamaInit.context); + + if (GraphRef.LlamaModel == nullptr) { + Env.deleteGraph(GId); + RET_ERROR(ErrNo::InvalidArgument, "load: Unable to init model."sv) + } + if (GraphRef.LlamaContext == nullptr) { + Env.deleteGraph(GId); + RET_ERROR(ErrNo::InvalidArgument, "load: Unable to init context."sv) + } + + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: initialize model with given parameters...Done"sv) + + // Store the loaded graph. + GraphId = GId; + Env.NNGraph[GId].setReady(); + + LOG_DEBUG(GraphRef.EnableDebugLog, "load...Done"sv) + return ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx"sv) + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + auto &CxtRef = Env.NNContext[ContextId].get(); + + LOG_INFO(GraphRef.EnableLog, "llama_system_info: {}"sv, + llama_print_system_info()) + + // Allocate the batch for input string prompt tokens. + CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); + CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; + + // Allocate the batch for single-token output sampling. + CxtRef.OutputBatch = allocBatch(1); + + // Allocate the sampler + CxtRef.LlamaSampler.reset( + common_sampler_init(GraphRef.LlamaModel.get(), GraphRef.Params.sparams)); + + Env.NNContext[ContextId].setReady(); + LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx...Done"sv) + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput"sv) + + // Handle Metadata at Index 1 + if (Index == 1) { + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: found Metadata, processing"sv) + bool IsModelUpdated = false; + bool IsContextUpdated = false; + bool IsSamplerUpdated = false; + const std::string Metadata( + reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + + auto Res = parseMetadata(GraphRef, CxtRef.Conf, Metadata, &IsModelUpdated, + &IsContextUpdated, &IsSamplerUpdated); + if (Res != ErrNo::Success) { + RET_ERROR(Res, "setInput: failed to parse metadata."sv) + } + + if (IsModelUpdated || GraphRef.LlamaModel == nullptr) { + // The llama model may be nullptr if set_input updated the model params + // last time. Therefore, in addition to updated model params, we should + // reload the llama model if the model is nullptr. + LOG_INFO(GraphRef.EnableLog, + "setInput: Reloading model due to parameter change"sv) + + // Prepare model parameters for the reload. + llama_model_params ModelParams = llama_model_default_params(); + ModelParams.n_gpu_layers = + static_cast(GraphRef.Params.n_gpu_layers); + ModelParams.main_gpu = static_cast(GraphRef.Params.main_gpu); + + // Free all resources that depend on the old model. + GraphRef.LlamaModel.reset(); + + // Due to the model change, the context and sampler should also be + // reloaded. The new context and sampler will be created in the next + // block. + GraphRef.LlamaContext.reset(); + if (CxtRef.LlamaSampler) { + CxtRef.LlamaSampler.reset(); + CxtRef.LlamaSampler = nullptr; + } + + // Attempt to load the model from file with new parameters. + GraphRef.LlamaModel.reset(llama_load_model_from_file( + GraphRef.Params.model.c_str(), ModelParams)); + if (GraphRef.LlamaModel == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init model."sv) + } + } + + // Reload context if its parameters changed OR if it was cleared by a model + // reload. + if (IsContextUpdated || GraphRef.LlamaContext == nullptr) { + LOG_INFO(GraphRef.EnableLog, + "setInput: Reloading llama context due to parameter change."sv) + GraphRef.LlamaContext.reset(); + llama_context_params CtxParams = + common_context_params_to_llama(GraphRef.Params); + GraphRef.LlamaContext.reset( + llama_new_context_with_model(GraphRef.LlamaModel.get(), CtxParams)); + if (GraphRef.LlamaContext == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init context."sv) + } + } + + // Re-initialize sampler if its parameters changed OR if it was cleared. + if (IsSamplerUpdated || CxtRef.LlamaSampler == nullptr) { + LOG_INFO(GraphRef.EnableLog, + "setInput: Re-initializing sampler due to parameter change."sv); + CxtRef.LlamaSampler.reset(common_sampler_init(GraphRef.LlamaModel.get(), + GraphRef.Params.sparams)); + if (CxtRef.LlamaSampler == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init sampler."sv) + } + } + + // Re-allocate batch if the batch size changed. + if (CxtRef.CurrentBatchSize != GraphRef.Params.n_batch) { + LOG_INFO(GraphRef.EnableLog, + "Re-allocating batch due to n_batch change."); + llama_batch_free(CxtRef.LlamaBatch); + CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); + if (!CxtRef.LlamaBatch.token) { + RET_ERROR(ErrNo::InvalidArgument, "Failed to re-allocate llama_batch."); + } + CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; + } + + Env.NNGraph[CxtRef.GraphId].setReady(); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: metadata processing...Done"); + return ErrNo::Success; + } + + if (Index != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "Only prompt (index 0) and metadata (index 1) are supported."); + } + + // Check that the graph is valid after reloading during the previous + // set_input. + if (!Env.NNGraph[CxtRef.GraphId].isReady()) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: Graph is invalid. Please reload again by passing metadata "sv + "in set_input or unload graph."sv) + } + + LOG_DEBUG(GraphRef.EnableLog, "setInput: Clearing KV cache for new prompt."sv) + llama_kv_cache_clear(GraphRef.LlamaContext.get()); + LOG_DEBUG(GraphRef.EnableLog, + "setInput: Clearing KV cache for new prompt...done"sv) + + // Check tensor type. + if (Tensor.RType != TensorType::U8) { + RET_ERROR(ErrNo::InvalidArgument, + "Input tensor must be a UTF-8 string (U8)."); + } + + // Tokenize the new prompt. + const std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt"sv) + CxtRef.LlamaInputs = + common_tokenize(GraphRef.LlamaContext.get(), Prompt, + llama_add_bos_token(GraphRef.LlamaModel.get()), true); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt...Done"sv) + + // Get the number of input tokens for the metadata. + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); + + // Reset state for the compute loop. + CxtRef.ComputeSingleStarted = false; + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput...Done"sv) + return ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}"sv, Index) + + // Handle Metadata Output at Index 1 + if (Index == 1) { + const std::string Metadata = buildOutputMetadata(CxtRef); + const size_t BytesToCopy = + std::min(static_cast(OutBuffer.size()), Metadata.length()); + std::copy_n(Metadata.data(), BytesToCopy, OutBuffer.data()); + BytesWritten = static_cast(Metadata.length()); + + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: Metadata (Index 1)...Done"sv) + return ErrNo::Success; + } + + const size_t BytesToCopy = std::min(static_cast(OutBuffer.size()), + CxtRef.LlamaOutputs.size()); + std::copy_n(CxtRef.LlamaOutputs.data(), BytesToCopy, OutBuffer.data()); + BytesWritten = CxtRef.LlamaOutputs.size(); + + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: Text (Index 0)...Done"sv) + return ErrNo::Success; +} + +Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "compute"); + + // Clear the context and reset the sampler. + clearContext(GraphRef, CxtRef); + + if (GraphRef.Params.embedding) { + return getEmbedding(GraphRef, CxtRef); + } + + // Evaluate the input tokens. + auto ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + + // Main prediction loop. + LOG_DEBUG(GraphRef.EnableDebugLog, "compute: enter main prediction loop"sv) + int64_t NPredict = CxtRef.Conf.NPredict; + if (NPredict < 0) { + NPredict = INT32_MAX; + } + + while (NPredict > 0) { + ReturnCode = sampleOutput(GraphRef, CxtRef); + if (ReturnCode != ErrNo::Success) { + break; + } + NPredict--; + } + + if (ReturnCode == ErrNo::EndOfSequence || ReturnCode == ErrNo::ContextFull) { + LOG_INFO(GraphRef.EnableLog, "compute finished with status: {}."sv, + static_cast(ReturnCode)) + return ErrNo::Success; + } + + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: enter main prediction loop...Done"sv) + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler.get()); + } + + LOG_DEBUG(GraphRef.EnableDebugLog, "compute...Done") + return ReturnCode; +} + +Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}"sv, Index) + + // Metadata Output at Index 1 + if (Index == 1) { + const std::string Metadata = buildOutputMetadata(CxtRef); + const size_t BytesToCopy = + std::min(static_cast(OutBuffer.size()), Metadata.length()); + std::copy_n(Metadata.data(), BytesToCopy, OutBuffer.data()); + BytesWritten = static_cast(Metadata.length()); + + LOG_DEBUG(GraphRef.EnableDebugLog, + "getOutputSingle: Metadata (Index 1)...Done"sv) + return ErrNo::Success; + } + + if (CxtRef.LlamaOutputTokens.empty()) { + BytesWritten = 0; + return ErrNo::Success; + } + + const std::string LastTokenStr = common_token_to_piece( + GraphRef.LlamaContext.get(), CxtRef.LlamaOutputTokens.back()); + + const size_t BytesToCopy = + std::min(static_cast(OutBuffer.size()), LastTokenStr.length()); + std::copy_n(LastTokenStr.data(), BytesToCopy, OutBuffer.data()); + BytesWritten = LastTokenStr.length(); + + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: Text (Index 0)...Done"sv) + return ErrNo::Success; +} + +Expect computeSingle(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle"sv) + + auto ReturnCode = ErrNo::Success; + if (!CxtRef.ComputeSingleStarted) { + // Clear the context and reset the sampler. + clearContext(GraphRef, CxtRef); + ReturnCode = evaluateInput(GraphRef, CxtRef, "computeSingle"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + + CxtRef.ComputeSingleStarted = true; + } + + // Main prediction process. + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: enter main prediction process"sv) + ReturnCode = sampleOutput(GraphRef, CxtRef, /* IsSingleTokenMode */ true); + if (ReturnCode != ErrNo::Success) { + CxtRef.ComputeSingleStarted = false; + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: enter main prediction process...Done"sv) + // End of main predict process. + + LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle...Done"sv) + return ReturnCode; +} + +Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle"); + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler.get()); + } + + // Reset the llama sampler. + common_sampler_reset(CxtRef.LlamaSampler.get()); + + // Clear the outputs. + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the previous output and tokens"sv) + CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the previous output and tokens...Done"sv) + + CxtRef.NPos = 0; + CxtRef.ComputeSingleStarted = false; + + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle...Done"sv) + return ErrNo::Success; +} + +Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { + + if (GraphId >= Env.NNGraph.size()) { + return ErrNo::Success; + } + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.LlamaModel == nullptr) { + return ErrNo::Success; + } + + LOG_DEBUG(GraphRef.EnableDebugLog, "unload"sv) + + Env.deleteGraph(GraphId); + + LOG_DEBUG(GraphRef.EnableDebugLog, "unload...Done"sv) + return ErrNo::Success; +} + +Expect finalizeExecCtx(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "finalizeExecCtx"sv) + + CxtRef.LlamaSampler.reset(); + llama_batch_free(CxtRef.LlamaBatch); + llama_batch_free(CxtRef.OutputBatch); + Env.deleteContext(ContextId); + + LOG_DEBUG(GraphRef.EnableDebugLog, "finalizeExecCtx...Done"sv) + return ErrNo::Success; +} + +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] BitNet backend is not built. Please build with " + "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET=ON."sv); + return ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WasiNNEnvironment &, Span>, Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect getOutputSingle(WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect computeSingle(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect finiSingle(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect unload(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect finalizeExecCtx(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} + +#endif +} // namespace WasmEdge::Host::WASINN::BitNet diff --git a/plugins/wasi_nn/wasinn_bitnet.h b/plugins/wasi_nn/wasinn_bitnet.h new file mode 100644 index 00000000..6fba1f39 --- /dev/null +++ b/plugins/wasi_nn/wasinn_bitnet.h @@ -0,0 +1,140 @@ +#pragma once + +#include "plugin/plugin.h" +#include "wasinntypes.h" +#include + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::BitNet { + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET + +struct LlamaModelDeleter { + void operator()(llama_model *Ptr) const { + if (Ptr) { + llama_free_model(Ptr); + } + } +}; +struct LlamaContextDeleter { + void operator()(llama_context *Ptr) const { + if (Ptr) { + llama_free(Ptr); + } + } +}; +struct CommonSamplerDeleter { + void operator()(common_sampler *Ptr) const { + if (Ptr) { + common_sampler_free(Ptr); + } + } +}; + +using LlamaModelPtr = std::unique_ptr; +using LlamaContextPtr = std::unique_ptr; +using CommonSamplerPtr = std::unique_ptr; + +enum class EmbdNormalizeType : int32_t { + // From: https://github.com/ggerganov/llama.cpp/blob/master/common/common.h + None = -1, + MaxAbsolute = 0, + Taxicab = 1, + Euclidean = 2, + PNorm = 3, +}; + +struct LocalConfig { + int64_t NPredict = -1; + bool StreamStdout = false; + std::string ReversePrompt; + EmbdNormalizeType EmbdNormalize = EmbdNormalizeType::Euclidean; +}; + +struct Graph { + bool EnableLog = false; + bool EnableDebugLog = false; + common_params Params; + LlamaModelPtr LlamaModel = nullptr; + LlamaContextPtr LlamaContext = nullptr; + LocalConfig Conf; +}; + +struct Context { +public: + Context(uint32_t GId, Graph &G) noexcept : GraphId(GId), Conf(G.Conf) {} + + uint32_t GraphId; + bool ComputeSingleStarted = false; + + int32_t NPos = 0; + std::vector LlamaInputs; + uint64_t LlamaNInputs = 0; + std::vector LlamaOutputTokens; + std::vector LlamaOutputs; + CommonSamplerPtr LlamaSampler = nullptr; + int64_t CurrentBatchSize = 0; + struct llama_batch LlamaBatch; + struct llama_batch OutputBatch; + + LocalConfig Conf; +}; + +#else + +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; + +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; + +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; + +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; + +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +Expect getOutputSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; + +Expect computeSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +Expect finiSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +Expect finalizeExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; + +} // namespace WasmEdge::Host::WASINN::BitNet diff --git a/plugins/wasi_nn/wasinn_chattts.cpp b/plugins/wasi_nn/wasinn_chattts.cpp new file mode 100644 index 00000000..5937d735 --- /dev/null +++ b/plugins/wasi_nn/wasinn_chattts.cpp @@ -0,0 +1,372 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_chattts.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +#include "simdjson.h" + +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__WIN32__) && \ + !defined(__TOS_WIN__) && !defined(__WINDOWS__) +#include +#endif +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::ChatTTS { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +#if defined(_WIN32) || defined(_WIN64) || defined(__WIN32__) || \ + defined(__TOS_WIN__) || defined(__WINDOWS__) +HINSTANCE SharedLib = LoadLibrary(PYTHON_LIB_PATH); +#else +void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); +#endif +Expect load(WASINN::WasiNNEnvironment &Env, + Span>, WASINN::Device, + uint32_t &GraphId) noexcept { + // Add a new graph. + uint32_t GId = Env.newGraph(Backend::ChatTTS); + auto &GraphRef = Env.NNGraph[GId].get(); + // Initialize the plugin parameters. + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] ChatTTS backend: Load."sv); + } + + // Create Model class + if (!Py_IsInitialized()) { + Py_Initialize(); + if (PyGILState_Check()) { + PyEval_SaveThread(); + } + } + GIL Lock; + if (GraphRef.ChatTTSModule == nullptr) { + GraphRef.ChatTTSModule = PyImport_ImportModule("ChatTTS"); + if (GraphRef.ChatTTSModule == nullptr) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Can not find ChatTTS library."sv); + Env.deleteGraph(GId); + return WASINN::ErrNo::RuntimeError; + } + } + if (GraphRef.Chat == nullptr) { + PyObject *ChatFunction = + PyObject_GetAttrString(GraphRef.ChatTTSModule, "Chat"); + if (ChatFunction == nullptr || !PyCallable_Check(ChatFunction)) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Can not find Chat class in ChatTTS."sv); + Env.deleteGraph(GId); + return WASINN::ErrNo::RuntimeError; + } + GraphRef.Chat = PyObject_CallObject(ChatFunction, nullptr); + Py_XDECREF(ChatFunction); + if (GraphRef.Chat == nullptr) { + spdlog::error("[WASI-NN] ChatTTS backend: Can not create chat."sv); + Env.deleteGraph(GId); + return WASINN::ErrNo::RuntimeError; + } + PyObject *LoadMethod = PyObject_GetAttrString(GraphRef.Chat, "load"); + if (LoadMethod == nullptr || !PyCallable_Check(LoadMethod)) { + spdlog::error("[WASI-NN] ChatTTS backend: Can not load chat."sv); + Env.deleteGraph(GId); + return WASINN::ErrNo::RuntimeError; + } + PyObject *Value = PyObject_CallObject(LoadMethod, nullptr); + Py_XDECREF(Value); + Py_XDECREF(LoadMethod); + } + // Store the loaded graph. + GraphId = GId; + Env.NNGraph[GId].setReady(); + + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + if (!Py_IsInitialized()) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Model has been released, please reload it."sv); + return WASINN::ErrNo::RuntimeError; + } + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, + const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (!Py_IsInitialized()) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Model has been released, please reload it."sv); + return WASINN::ErrNo::RuntimeError; + } + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] ChatTTS backend: setInput"sv); + } + if (Index == 0) { + // Set the input. + std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + CxtRef.Inputs.clear(); + CxtRef.Inputs = Prompt; + return WASINN::ErrNo::Success; + } else if (Index == 1) { + // Set metadata. + std::string Metadata = std::string( + reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] ChatTTS backend: Parse metadata error"sv); + return ErrNo::InvalidEncoding; + } + GIL Lock; + // Handle Refine Text Params + PyObject *PromptObj = nullptr; + if (Doc.at_key("prompt").error() == simdjson::SUCCESS) { + std::string_view PromptView; + auto Err = Doc["prompt"].get().get(PromptView); + if (Err) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Unable to retrieve the prompt option."sv); + return ErrNo::InvalidArgument; + } + PromptObj = PyUnicode_FromString(std::string(PromptView).c_str()); + } + if (PromptObj != nullptr) { + PyObject *Args = PyTuple_New(0); + PyObject *Kwargs = PyDict_New(); + PyDict_SetItemString(Kwargs, "prompt", PromptObj); + PyObject *RefineTextParamsFun = + PyObject_GetAttrString(GraphRef.Chat, "RefineTextParams"); + GraphRef.ParamsRefineText = + PyObject_Call(RefineTextParamsFun, Args, Kwargs); + Py_XDECREF(PromptObj); + Py_XDECREF(Args); + Py_XDECREF(Kwargs); + Py_XDECREF(RefineTextParamsFun); + } + // Handle Infer Code Params + PyObject *InferKwargs = PyDict_New(); + if (Doc.at_key("temperature").error() == simdjson::SUCCESS) { + double Temperature; + auto Err = Doc["temperature"].get().get(Temperature); + if (Err) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Unable to retrieve the temperature option."sv); + return ErrNo::InvalidArgument; + } + PyObject *TemperatureObject = PyFloat_FromDouble(Temperature); + PyDict_SetItemString(InferKwargs, "temperature", TemperatureObject); + Py_XDECREF(TemperatureObject); + } + if (Doc.at_key("top_K").error() == simdjson::SUCCESS) { + double TopK; + auto Err = Doc["top_K"].get().get(TopK); + if (Err) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Unable to retrieve the topK option."sv); + return ErrNo::InvalidArgument; + } + PyObject *TopKObject = PyFloat_FromDouble(TopK); + PyDict_SetItemString(InferKwargs, "top_K", TopKObject); + Py_XDECREF(TopKObject); + } + if (Doc.at_key("top_P").error() == simdjson::SUCCESS) { + double TopP; + auto Err = Doc["top_P"].get().get(TopP); + if (Err) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Unable to retrieve the temperature option."sv); + return ErrNo::InvalidArgument; + } + PyObject *TopPObject = PyFloat_FromDouble(TopP); + PyDict_SetItemString(InferKwargs, "top_P", TopPObject); + Py_XDECREF(TopPObject); + } + if (Doc.at_key("spk_emb").error() == simdjson::SUCCESS) { + std::string_view SpkEmb; + auto Err = Doc["spk_emb"].get().get(SpkEmb); + if (Err) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Unable to retrieve the spk_emb option."sv); + return ErrNo::InvalidArgument; + } + if (SpkEmb == "random") { + PyObject *SampleRandomSpeaker = + PyObject_GetAttrString(GraphRef.Chat, "sample_random_speaker"); + PyObject *Spk = PyObject_CallNoArgs(SampleRandomSpeaker); + PyDict_SetItemString(InferKwargs, "spk_emb", Spk); + Py_XDECREF(SampleRandomSpeaker); + Py_XDECREF(Spk); + } else { + PyObject *Spk = PyUnicode_FromString(std::string(SpkEmb).c_str()); + PyDict_SetItemString(InferKwargs, "spk_emb", Spk); + Py_XDECREF(Spk); + } + } + if (PyDict_Size(InferKwargs) != 0) { + PyObject *Args = PyTuple_New(0); + PyObject *InferCodeParams = + PyObject_GetAttrString(GraphRef.Chat, "InferCodeParams"); + GraphRef.ParamsInferCode = + PyObject_Call(InferCodeParams, Args, InferKwargs); + Py_XDECREF(Args); + Py_XDECREF(InferCodeParams); + } + Py_XDECREF(InferKwargs); + return WASINN::ErrNo::Success; + } + return WASINN::ErrNo::InvalidArgument; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] ChatTTS backend: getOutput"sv); + } + if (Index == 0) { + std::copy_n(CxtRef.Outputs.data(), CxtRef.Outputs.size(), OutBuffer.data()); + BytesWritten = CxtRef.Outputs.size(); + return WASINN::ErrNo::Success; + } + return WASINN::ErrNo::InvalidArgument; +} + +Expect compute(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + if (!Py_IsInitialized()) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Model has been released, please reload it."sv); + return WASINN::ErrNo::RuntimeError; + } + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] ChatTTS backend: compute"sv); + } + if (CxtRef.Inputs.size() == 0) { + spdlog::error("[WASI-NN] ChatTTS backend: Input is not set!"sv); + return ErrNo::InvalidArgument; + } + GIL Lock; + PyObject *InputStr = PyUnicode_FromString(CxtRef.Inputs.c_str()); + PyObject *InferMethod = PyObject_GetAttrString(GraphRef.Chat, "infer"); + PyObject *Result = nullptr; + if (InferMethod == nullptr || !PyCallable_Check(InferMethod)) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Can not find infer method in Chat."sv); + PyErr_Print(); + Py_XDECREF(InferMethod); + return WASINN::ErrNo::RuntimeError; + } + if (GraphRef.ParamsRefineText == nullptr && + GraphRef.ParamsInferCode == nullptr) { + PyObject *Args = PyTuple_Pack(1, InputStr); + Result = PyObject_CallObject(InferMethod, Args); + Py_XDECREF(Args); + } else { + PyObject *Args = PyTuple_New(0); + PyObject *Kwargs = PyDict_New(); + PyDict_SetItemString(Kwargs, "text", InputStr); + if (GraphRef.ParamsRefineText != nullptr) { + PyDict_SetItemString(Kwargs, "params_refine_text", + GraphRef.ParamsRefineText); + } + if (GraphRef.ParamsInferCode != nullptr) { + PyDict_SetItemString(Kwargs, "params_infer_code", + GraphRef.ParamsInferCode); + } + Result = PyObject_Call(InferMethod, Args, Kwargs); + Py_XDECREF(Args); + Py_XDECREF(Kwargs); + } + if (Result != nullptr) { + PyObject *Index = PyLong_FromLong(0); + PyObject *Wav0 = PyObject_GetItem(Result, Index); + Py_XDECREF(Index); + PyObject *BytesObj = PyObject_CallMethod(Wav0, "tobytes", nullptr); + Py_XDECREF(Wav0); + char *Bytes = PyBytes_AsString(BytesObj); + Py_ssize_t size = PyBytes_Size(BytesObj); + CxtRef.Outputs = std::vector(Bytes, Bytes + size); + Py_XDECREF(BytesObj); + } else { + spdlog::error( + "[WASI-NN] ChatTTS backend: Can not get output from infer method."sv); + Py_XDECREF(InputStr); + Py_XDECREF(InferMethod); + return WASINN::ErrNo::RuntimeError; + } + Py_XDECREF(Result); + Py_XDECREF(InputStr); + Py_XDECREF(InferMethod); + return WASINN::ErrNo::Success; +} + +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] ChatTTS backend: start unload."sv); + } + if (Py_IsInitialized()) { + GIL Lock; + Py_XDECREF(GraphRef.ParamsRefineText); + Py_XDECREF(GraphRef.ParamsInferCode); + Py_XDECREF(GraphRef.Chat); + Py_XDECREF(GraphRef.ChatTTSModule); + GraphRef.ParamsRefineText = nullptr; + GraphRef.ParamsInferCode = nullptr; + GraphRef.Chat = nullptr; + GraphRef.ChatTTSModule = nullptr; + } + Env.deleteGraph(GraphId); + return WASINN::ErrNo::Success; +} + +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] ChatTTS backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"ChatTTS\" to build it."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect unload(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif + +} // namespace WasmEdge::Host::WASINN::ChatTTS diff --git a/plugins/wasi_nn/wasinn_chattts.h b/plugins/wasi_nn/wasinn_chattts.h new file mode 100644 index 00000000..8adf8429 --- /dev/null +++ b/plugins/wasi_nn/wasinn_chattts.h @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::ChatTTS { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +class GIL { +private: + PyGILState_STATE GState; + +public: + GIL() : GState(PyGILState_Ensure()) {} + ~GIL() { PyGILState_Release(GState); } + GIL(const GIL &) = delete; + GIL &operator=(const GIL &) = delete; +}; +struct Graph { + bool EnableDebugLog = false; + Graph() noexcept { + if (!Py_IsInitialized()) { + Py_Initialize(); + if (PyGILState_Check()) { + PyEval_SaveThread(); + } + } + } + ~Graph() noexcept { + if (Py_IsInitialized()) { + GIL Lock; + Py_XDECREF(Chat); + Py_XDECREF(ChatTTSModule); + } + } + PyObject *Chat = nullptr; + PyObject *ChatTTSModule = nullptr; + PyObject *ParamsRefineText = nullptr; + PyObject *ParamsInferCode = nullptr; +}; +struct Context { + Context(uint32_t Gid, Graph &) noexcept : GraphId(Gid) {} + uint32_t GraphId; + std::string Inputs; + std::vector Outputs; +}; +#else +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; +} // namespace WasmEdge::Host::WASINN::ChatTTS diff --git a/plugins/wasi_nn/wasinn_mlx.cpp b/plugins/wasi_nn/wasinn_mlx.cpp new file mode 100644 index 00000000..4ca60e58 --- /dev/null +++ b/plugins/wasi_nn/wasinn_mlx.cpp @@ -0,0 +1,676 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_mlx.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX + +#include "MLX/mlx/base.h" +#include "MLX/model/converter.h" +#include "MLX/model/gemma3/gemma3.h" +#include "MLX/model/llm/registry.h" +#include "MLX/model/llm/transformer.h" +#include "MLX/model/utils.h" +#include "MLX/model/whisper/whisper.h" +#include "MLX/model/whisper_transcribe.h" +#include "MLX/prompt/prompt.h" +#include +#include + +#include + +#include "host/wasi/vfs_io.h" +#endif + +namespace WasmEdge::Host::WASINN::MLX { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX + +mx::array fromBytes(const Span &Bytes) { + if (Bytes.size() < 9) { + spdlog::error( + "[WASI-NN] MLX backend: Tensor data must be at least 9 bytes long, current size: {}."sv, + Bytes.size()); + return mx::array({0.0f}); + } + + size_t Offset = 0; + + uint8_t RtypeValue = Bytes[Offset]; + Offset += 1; + + uint32_t DimBufLen; + std::memcpy(&DimBufLen, &Bytes[Offset], 4); + Offset += 4; + + std::vector Shape; + for (size_t I = 0; I < DimBufLen; I += 4) { + uint32_t Dim; + std::memcpy(&Dim, &Bytes[Offset + I], 4); + Shape.push_back(static_cast(Dim)); + } + Offset += DimBufLen; + + uint32_t DataBufLen; + std::memcpy(&DataBufLen, &Bytes[Offset], 4); + Offset += 4; + + const void *DataPtr = &Bytes[Offset]; + switch (RtypeValue) { + case 0: { // F16 + return mx::array(static_cast(DataPtr), Shape, + mx::float16); + } + case 1: { // F32 + return mx::array(static_cast(DataPtr), Shape, mx::float32); + } + case 2: { // F64 + return mx::array(static_cast(DataPtr), Shape, mx::float64); + } + case 3: { // U8 + return mx::array(static_cast(DataPtr), Shape, mx::uint8); + } + case 4: { // I32 + return mx::array(static_cast(DataPtr), Shape, mx::int32); + } + case 5: { // I64 + return mx::array(static_cast(DataPtr), Shape, mx::int64); + } + default: + spdlog::error("[WASI-NN] MLX backend: Unsupported rtype: {}", RtypeValue); + return mx::array({0.0f}); + } +} + +std::vector toBytes(const mx::array &Arr) { + std::vector Result; + + uint8_t RtypeValue; + switch (Arr.dtype()) { + case mx::float16: + RtypeValue = 0; + break; + case mx::float32: + RtypeValue = 1; + break; + case mx::float64: + RtypeValue = 2; + break; + case mx::uint8: + RtypeValue = 3; + break; + case mx::int32: + RtypeValue = 4; + break; + case mx::int64: + RtypeValue = 5; + break; + default: + spdlog::error( + "[WASI-NN] MLX backend: Unsupported dtype to convert to Processor Tensor"sv); + return Result; + } + Result.push_back(RtypeValue); + + std::vector DimBuf; + auto Shape = Arr.shape(); + for (int Dim : Shape) { + uint32_t DimData = static_cast(Dim); + const uint8_t *DimBytes = reinterpret_cast(&DimData); + DimBuf.insert(DimBuf.end(), DimBytes, DimBytes + 4); + } + + uint32_t DimBufLen = static_cast(DimBuf.size()); + const uint8_t *DimLenBytes = reinterpret_cast(&DimBufLen); + Result.insert(Result.end(), DimLenBytes, DimLenBytes + 4); + + Result.insert(Result.end(), DimBuf.begin(), DimBuf.end()); + + std::vector DataBuf; + mx::eval(Arr); + + switch (Arr.dtype()) { + case mx::float16: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + const uint8_t *DataBytes = reinterpret_cast(Data); + DataBuf.insert(DataBuf.end(), DataBytes, DataBytes + ByteSize); + break; + } + case mx::float32: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + const uint8_t *DataBytes = reinterpret_cast(Data); + DataBuf.insert(DataBuf.end(), DataBytes, DataBytes + ByteSize); + break; + } + case mx::float64: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + const uint8_t *DataBytes = reinterpret_cast(Data); + DataBuf.insert(DataBuf.end(), DataBytes, DataBytes + ByteSize); + break; + } + case mx::uint8: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + DataBuf.insert(DataBuf.end(), Data, Data + ByteSize); + break; + } + case mx::int32: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + const uint8_t *DataBytes = reinterpret_cast(Data); + DataBuf.insert(DataBuf.end(), DataBytes, DataBytes + ByteSize); + break; + } + case mx::int64: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + const uint8_t *DataBytes = reinterpret_cast(Data); + DataBuf.insert(DataBuf.end(), DataBytes, DataBytes + ByteSize); + break; + } + default: + spdlog::error("[WASI-NN] MLX backend: Unsupported MLX dtype for conversion " + "to Rust Tensor"sv); + break; + } + + uint32_t DataBufLen = static_cast(DataBuf.size()); + const uint8_t *DataLenBytes = reinterpret_cast(&DataBufLen); + Result.insert(Result.end(), DataLenBytes, DataLenBytes + 4); + Result.insert(Result.end(), DataBuf.begin(), DataBuf.end()); + + return Result; +} + +enum AnswerSataus { + STOP, + WAIT, + GO, +}; + +AnswerSataus answerSataus(std::string Text, std::string End) { + if (endsWith(Text, End)) { + return STOP; + } + for (int Idx = 1; Idx < static_cast(End.size()); Idx++) { + if (endsWith(Text, End.substr(0, Idx))) { + return WAIT; + } + } + return GO; +} + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, WASINN::Device, + uint32_t &GraphId) noexcept { + // Add a new graph. + uint32_t GId = Env.newGraph(Backend::MLX); + auto &GraphRef = Env.NNGraph[GId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] MLX backend: Load."sv); + } + std::string TokenizerPath; + // Parse metadata. + if (Builders.size() <= 1) { + spdlog::error( + "[WASI-NN] MLX backend: Lack model weight or required metadata (model_type)."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + const std::string Metadata = std::string( + reinterpret_cast(Builders.back().data()), Builders.back().size()); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] MLX backend: Parse metadata error"sv); + Env.deleteGraph(GId); + return ErrNo::InvalidEncoding; + } + if (Doc.at_key("model_type").error() == simdjson::SUCCESS) { + std::string_view ModelType; + auto Err = Doc["model_type"].get().get(ModelType); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the model_type option."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + GraphRef.ModelType = ModelType; + } else { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the model_type option."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + if (Doc.at_key("enable_debug_log").error() == simdjson::SUCCESS) { + bool EnableDebugLog; + auto Err = Doc["enable_debug_log"].get().get(EnableDebugLog); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the enable_debug_log option."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + GraphRef.EnableDebugLog = EnableDebugLog; + } + if (Doc.at_key("tokenizer").error() == simdjson::SUCCESS) { + std::string_view TokenizerPathView; + auto Err = Doc["tokenizer"].get().get(TokenizerPathView); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the tokenizer option."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + TokenizerPath = TokenizerPathView; + } + if (Doc.at_key("max_token").error() == simdjson::SUCCESS) { + uint64_t MaxToken; + auto Err = Doc["max_token"].get().get(MaxToken); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the max_token option."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + GraphRef.MaxToken = MaxToken; + } + if (Doc.at_key("q_bits").error() == simdjson::SUCCESS && + Doc.at_key("group_size").error() == simdjson::SUCCESS && + Doc.at_key("is_quantized").error() == simdjson::SUCCESS) { + uint64_t QBits; + uint64_t GroupSize; + bool IsQuantized; + auto ErrQBits = Doc["q_bits"].get().get(QBits); + auto ErrGroupSize = Doc["group_size"].get().get(GroupSize); + auto ErrIsQuantized = Doc["is_quantized"].get().get(IsQuantized); + if (ErrQBits || ErrGroupSize || ErrIsQuantized) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the q_bits or group_size option."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + GraphRef.IsQuantized = IsQuantized; + GraphRef.QBits = QBits; + GraphRef.GroupSize = GroupSize; + } + if (Doc.at_key("quantization").error() == simdjson::SUCCESS) { + auto QuantResult = Doc["quantization"].get_object(); + auto Err = QuantResult.value()["group_size"].get().get( + GraphRef.GroupSize); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the group size from quantization option."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + Err = QuantResult.value()["bits"].get().get(GraphRef.QBits); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the group size from quantization option."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + GraphRef.IsQuantized = true; + } + + std::unordered_map Weights; + // Handle the model path. + for (size_t Idx = 0; Idx < Builders.size() - 1; Idx++) { + auto WeightData = Builders[Idx]; + const std::string BinModel(reinterpret_cast(WeightData.data()), + WeightData.size()); + spdlog::info("[WASI-NN] MLX BinModel: {}"sv, BinModel.size()); + if (BinModel.size() == 0) { + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + std::string ModelFilePath; + if (BinModel.substr(0, 8) == "preload:"sv) { + ModelFilePath = BinModel.substr(8); + } else { + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] MLX backend: Model path not found in nn-preload, " + "write model into a tmpfile."sv); + } + // Write model to file. + // TODO: handle different model format. + ModelFilePath = "MLX" + std::to_string(Idx) + ".safetensors"; + WasmEdge::FStream::OFStream TempFile( + ModelFilePath, std::ios_base::out | std::ios_base::binary, + Env.getEnv()); + if (!TempFile) { + spdlog::error( + "[WASI-NN] MLX backend: Failed to create the temporary file. "sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + TempFile.write(BinModel.data(), BinModel.size()); + TempFile.close(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] MLX backend: Write model into a tmpfile...Done"sv); + } + } + + if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { + auto Weight = llamaToMlxllm(ModelFilePath); + Weights.insert(Weight.begin(), Weight.end()); + } else if (GraphRef.ModelType == "llama_3_8b") { + auto Weight = llamaToMlxllm(ModelFilePath); + Weights.insert(Weight.begin(), Weight.end()); + } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { + auto Weight = llamaToMlxllm(ModelFilePath); + Weights.insert(Weight.begin(), Weight.end()); + } else if (GraphRef.ModelType == "gemma3") { + auto Weight = mx::load_safetensors(ModelFilePath); + Weights.insert(Weight.first.begin(), Weight.first.end()); + } else if (GraphRef.ModelType == "whisper") { + auto Weight = mx::load_safetensors(ModelFilePath); + Weights.insert(Weight.first.begin(), Weight.first.end()); + } else { + spdlog::error("[WASI-NN] MLX backend: Model type {} not supported."sv, + GraphRef.ModelType); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + } + // Create Model. + if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { + GraphRef.Model = llm::tinyLlama11BChatV10(); + GraphRef.Prmopt = TinyLLaMAPrompt(); + GraphRef.ModelArch = "llm"; + } else if (GraphRef.ModelType == "llama_3_8b") { + GraphRef.Model = llm::llama38b(); + GraphRef.Prmopt = LLaMA3Prompt(); + GraphRef.ModelArch = "llm"; + } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { + GraphRef.Model = llm::llama27bChat(); + GraphRef.Prmopt = LLaMA2Prompt(); + GraphRef.ModelArch = "llm"; + } else if (GraphRef.ModelType == "gemma3") { + auto Obj = Doc.get_object(); + gemma3::ModelConfig ModelConfigObj = + gemma3::ModelConfig::fromDict(Obj.value()); + ModelConfigObj.VisionConfig = + (Obj.at_key("vision_config").error() == simdjson::SUCCESS) + ? gemma3::VisionConfig::fromDict( + Obj["vision_config"].get_object().value()) + : gemma3::VisionConfig(); + ModelConfigObj.TextConfig = + (Obj.at_key("text_config").error() == simdjson::SUCCESS) + ? gemma3::TextConfig::fromDict( + Obj["text_config"].get_object().value()) + : gemma3::TextConfig(); + GraphRef.Model = std::dynamic_pointer_cast( + std::make_shared(gemma3::Model(ModelConfigObj))); + Weights = std::dynamic_pointer_cast(GraphRef.Model) + ->sanitize(Weights); + Weights = + gemma3::VisionModel(ModelConfigObj.VisionConfig).sanitize(Weights); + GraphRef.ModelArch = "vlm"; + } else if (GraphRef.ModelType == "whisper") { + auto Obj = Doc.get_object(); + whisper::ModelDimensions DefaultDims = + whisper::ModelDimensions::fromDict(Obj.value()); + GraphRef.Model = std::dynamic_pointer_cast( + std::make_shared(DefaultDims)); + GraphRef.ModelArch = "whisper"; + } else { + spdlog::error("[WASI-NN] MLX backend: Model type {} not supported."sv, + GraphRef.ModelType); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + + // Load tokenizer. + if (!TokenizerPath.empty()) { + auto Bytes = loadBytesFromFile(TokenizerPath, Env.getEnv()); + if (Bytes.empty()) { + spdlog::error("[WASI-NN] MLX backend: Load tokenizer failed."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + GraphRef.Tok = tokenizers::Tokenizer::FromBlobJSON(Bytes); + } else if (GraphRef.ModelArch == "llm") { + spdlog::error("[WASI-NN] MLX backend: Tokenizer path not found."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + + if (GraphRef.QBits != 0 && GraphRef.GroupSize != 0 && GraphRef.IsQuantized) { + spdlog::info( + "[WASI-NN] MLX backend: load Quantized model with q_bits: {} and group_size: {}"sv, + GraphRef.QBits, GraphRef.GroupSize); + GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits, "", + Weights); + } + + // Load weight. + if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { + GraphRef.Model->update(Weights); + } else if (GraphRef.ModelType == "llama_3_8b") { + GraphRef.Model->update(Weights); + } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { + GraphRef.Model->update(Weights); + } else if (GraphRef.ModelType == "gemma3") { + GraphRef.Model->update(Weights); + } else if (GraphRef.ModelType == "whisper") { + GraphRef.Model->update(Weights); + } else { + spdlog::error("[WASI-NN] MLX backend: Model type {} not supported."sv, + GraphRef.ModelType); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + + if (GraphRef.QBits != 0 && GraphRef.GroupSize != 0 && !GraphRef.IsQuantized) { + spdlog::info( + "[WASI-NN] MLX backend: Quantize model with q_bits: {} and group_size: {}"sv, + GraphRef.QBits, GraphRef.GroupSize); + GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits); + } + + GraphId = GId; + Env.NNGraph[GId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); + auto &GraphRef = Env.NNGraph[GraphId].get(); + auto &CxtRef = Env.NNContext[ContextId].get(); + if (GraphRef.ModelArch == "llm") { + CxtRef.Inputs = LLMInput(); + } else if (GraphRef.ModelArch == "vlm") { + CxtRef.Inputs = VLMInput(); + } else if (GraphRef.ModelArch == "whisper") { + CxtRef.Inputs = WhisperInput(); + } else { + spdlog::error( + "[WASI-NN] MLX backend: Model architecture {} not supported."sv, + GraphRef.ModelArch); + Env.deleteContext(ContextId); + return ErrNo::InvalidArgument; + } + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, + const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] MLX backend: setInput"sv); + } + + if (GraphRef.ModelArch == "llm") { + std::get(CxtRef.Inputs).Prompt = + std::string(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + } else if (GraphRef.ModelArch == "vlm") { + if (Index == 0) { + std::get(CxtRef.Inputs).Prompt = fromBytes(Tensor.Tensor); + } else if (Index == 1) { + std::get(CxtRef.Inputs).Pixel = fromBytes(Tensor.Tensor); + } else if (Index == 2) { + std::get(CxtRef.Inputs).Mask = fromBytes(Tensor.Tensor); + } else { + spdlog::error("[WASI-NN] MLX backend: Index out of range."sv); + return ErrNo::InvalidArgument; + } + } else if (GraphRef.ModelArch == "whisper") { + std::get(CxtRef.Inputs).Audio = + std::string(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + } else { + spdlog::error( + "[WASI-NN] MLX backend: Model architecture {} not supported."sv, + GraphRef.ModelArch); + return ErrNo::InvalidArgument; + } + return WASINN::ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] MLX backend: getOutput"sv); + } + if (GraphRef.ModelArch == "llm") { + auto *Output = std::get_if(&CxtRef.Outputs); + if (Output != nullptr) { + std::copy_n(Output->Answer.data(), Output->Answer.length(), + OutBuffer.data()); + BytesWritten = Output->Answer.length(); + } else { + spdlog::error("[WASI-NN] MLX backend: No output found."sv); + return ErrNo::InvalidArgument; + } + } else if (GraphRef.ModelArch == "vlm") { + auto *Output = std::get_if(&CxtRef.Outputs); + if (Output != nullptr) { + auto OutputBytes = toBytes(Output->Answer); + std::copy_n(OutputBytes.data(), OutputBytes.size(), OutBuffer.data()); + BytesWritten = OutputBytes.size(); + } else { + spdlog::error("[WASI-NN] MLX backend: No output found."sv); + return ErrNo::InvalidArgument; + } + } else if (GraphRef.ModelArch == "whisper") { + auto *Output = std::get_if(&CxtRef.Outputs); + if (Output != nullptr) { + std::string Text = Output->Text; + std::copy_n(Text.data(), Text.length(), OutBuffer.data()); + BytesWritten = Text.length(); + } else { + spdlog::error("[WASI-NN] MLX backend: No output found."sv); + return ErrNo::InvalidArgument; + } + + } else { + spdlog::error( + "[WASI-NN] MLX backend: Model architecture {} not supported."sv, + GraphRef.ModelArch); + return ErrNo::InvalidArgument; + } + return WASINN::ErrNo::Success; +} + +Expect compute(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + const auto Start{std::chrono::steady_clock::now()}; + size_t TokenListSize = 0; + + if (GraphRef.ModelArch == "llm" && GraphRef.Tok == nullptr) { + spdlog::error("[WASI-NN] MLX backend: Tokenizer not loaded."sv); + return ErrNo::InvalidArgument; + } + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] MLX backend: compute"sv); + } + if (GraphRef.ModelArch == "llm") { + auto Result = + std::dynamic_pointer_cast(GraphRef.Model) + ->generate(std::get(CxtRef.Inputs).Prompt, + GraphRef.Prmopt, GraphRef.MaxToken, false, GraphRef.Tok); + CxtRef.Outputs = LLMOutput({Result.Answer}); + TokenListSize = Result.TokenList.size(); + } else if (GraphRef.ModelArch == "vlm") { + auto &Input = std::get(CxtRef.Inputs); + std::map> + Kwargs; + Kwargs.insert({"input_ids", Input.Prompt}); + Kwargs.insert({"pixel_values", Input.Pixel}); + Kwargs.insert({"mask", Input.Mask}); + auto TokenList = std::dynamic_pointer_cast(GraphRef.Model) + ->generate({}, std::nullopt, false, Kwargs); + auto TokenArray = + mx::array(TokenList.data(), {static_cast(TokenList.size())}); + CxtRef.Outputs = VLMOutput({TokenArray}); + TokenListSize = TokenList.size(); + } else if (GraphRef.ModelArch == "whisper") { + CxtRef.Outputs = whisper::transcribe( + std::get(CxtRef.Inputs).Audio, + std::dynamic_pointer_cast(GraphRef.Model), false); + } else { + spdlog::error( + "[WASI-NN] MLX backend: Model architecture {} not supported."sv, + GraphRef.ModelArch); + return ErrNo::InvalidArgument; + } + const auto End{std::chrono::steady_clock::now()}; + const std::chrono::duration ElapsedSeconds{End - Start}; + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] MLX backend: Generate {} tokens."sv, TokenListSize); + spdlog::info("Elapsed time: {} s. TPS: {}.", ElapsedSeconds.count(), + TokenListSize / ElapsedSeconds.count()); + } + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] MLX backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"MLX\" to build it."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/wasinn_mlx.h b/plugins/wasi_nn/wasinn_mlx.h new file mode 100644 index 00000000..ab519bbf --- /dev/null +++ b/plugins/wasi_nn/wasinn_mlx.h @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" + +#include + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +#include "MLX/mlx/base.h" +#include "MLX/mlx/transformer.h" +#include "MLX/model/llm/transformer.h" +#include "MLX/prompt/prompt.h" +#include + +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::MLX { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +struct LLMInput { + std::string Prompt = {}; +}; +struct LLMOutput { + std::string Answer = {}; +}; +struct VLMInput { + mx::array Prompt = mx::array({}); + mx::array Pixel = mx::array({}); + mx::array Mask = mx::array({}); +}; +struct VLMOutput { + mx::array Answer = mx::array({}); +}; +struct WhisperInput { + std::string Audio; +}; +struct Graph { + std::string ModelType; + std::string ModelArch; + std::unique_ptr Tok = nullptr; + std::shared_ptr Model; + double Temp = 0.0; + bool EnableDebugLog = false; + bool IsQuantized = false; + uint64_t MaxToken = 1024; + uint64_t QBits = 0; + uint64_t GroupSize = 0; + BasePrompt Prmopt; +}; +struct Context { + Context(uint32_t Gid, Graph &) noexcept : GraphId(Gid) {} + uint32_t GraphId; + std::variant Inputs; + std::variant Outputs; +}; +#else +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/wasinn_neuralspeed.cpp b/plugins/wasi_nn/wasinn_neuralspeed.cpp new file mode 100644 index 00000000..cc36ea09 --- /dev/null +++ b/plugins/wasi_nn/wasinn_neuralspeed.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_neuralspeed.h" +#include "wasinnenv.h" + +namespace WasmEdge::Host::WASINN::NeuralSpeed { +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] Neural Speed backend is removed due to the upstream " + "end-of-life. Reference: " + "https://github.com/intel/neural-speed"sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +} // namespace WasmEdge::Host::WASINN::NeuralSpeed diff --git a/plugins/wasi_nn/wasinn_neuralspeed.h b/plugins/wasi_nn/wasinn_neuralspeed.h new file mode 100644 index 00000000..974f3dc8 --- /dev/null +++ b/plugins/wasi_nn/wasinn_neuralspeed.h @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::NeuralSpeed { +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::NeuralSpeed diff --git a/plugins/wasi_nn/wasinn_onnx.cpp b/plugins/wasi_nn/wasinn_onnx.cpp new file mode 100644 index 00000000..744cc263 --- /dev/null +++ b/plugins/wasi_nn/wasinn_onnx.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_onnx.h" +#include "wasinnenv.h" + +using namespace std::literals; + +namespace WasmEdge::Host::WASINN::ONNX { +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] ONNX backend is not supported."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +} // namespace WasmEdge::Host::WASINN::ONNX diff --git a/plugins/wasi_nn/wasinn_onnx.h b/plugins/wasi_nn/wasinn_onnx.h new file mode 100644 index 00000000..b546e1d9 --- /dev/null +++ b/plugins/wasi_nn/wasinn_onnx.h @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::ONNX { +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::ONNX diff --git a/plugins/wasi_nn/wasinn_openvino.cpp b/plugins/wasi_nn/wasinn_openvino.cpp new file mode 100644 index 00000000..ab5ee6cd --- /dev/null +++ b/plugins/wasi_nn/wasinn_openvino.cpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_openvino.h" +#include "wasinnenv.h" + +#include + +using namespace std::literals; + +namespace WasmEdge::Host::WASINN::OpenVINO { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO + +static Expect +GetDeviceString(WASINN::Device TargetDevice, + std::string &DeviceString) noexcept { + switch (TargetDevice) { + case Device::AUTO: + case Device::CPU: + DeviceString = "CPU"; + break; + case Device::GPU: + DeviceString = "GPU"; + break; + default: + spdlog::error("[WASI-NN] Unsupported device type"sv); + return WASINN::ErrNo::InvalidArgument; + } + return WASINN::ErrNo::Success; +} + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept { + // The graph builder length must be 2. + if (Builders.size() != 2) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 2"sv, + Builders.size()); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the XML and Weight raw buffer. + // Builder-0: the XML string + // Builder-1: the Weight binary + auto XML = Builders[0]; + auto Weight = Builders[1]; + + // Add a new graph. + uint32_t GId = Env.newGraph(Backend::OpenVINO); + auto &GraphRef = Env.NNGraph[GId].get(); + + // Store device information + GraphRef.TargetDevice = Device; + + try { + std::string ModelString(reinterpret_cast(XML.data()), + XML.size()); + GraphRef.OpenVINOIWeightTensor = + ov::Tensor(ov::element::Type_t::u8, {Weight.size()}); + std::copy_n(Weight.data(), Weight.size(), + static_cast(GraphRef.OpenVINOIWeightTensor.data())); + GraphRef.OpenVINOModel = Env.OpenVINOCore.read_model( + ModelString, GraphRef.OpenVINOIWeightTensor); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Model Load Exception: {}"sv, EX.what()); + Env.deleteGraph(GId); + return WASINN::ErrNo::RuntimeError; + } + // Store the loaded graph. + GraphId = GId; + Env.NNGraph[GId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept { + // Check the network and the execution network with the graph ID. + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.OpenVINOModel == nullptr) { + spdlog::error("[WASI-NN] Model for Graph:{} is empty!"sv, GraphId); + return WASINN::ErrNo::MissingMemory; + } + // Create context. + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + if (GraphRef.OpenVINOModel == nullptr) { + spdlog::error("[WASI-NN] The founded openvino session is empty"sv); + return WASINN::ErrNo::MissingMemory; + } + + if (Tensor.Dimension.size() > 8) { + spdlog::error("[WASI-NN] Tensor dimension is out of range, expect it under " + "8-dim, but got {}-dim."sv, + Tensor.Dimension.size()); + return WASINN::ErrNo::InvalidArgument; + } + if (Tensor.RType != WASINN::TensorType::F32) { + spdlog::error( + "[WASI-NN] Only F32 inputs and outputs are supported for now."sv); + return WASINN::ErrNo::InvalidArgument; + } + + // Check the input index. + if (GraphRef.OpenVINOModel->inputs().size() <= Index) { + spdlog::error( + "[WASI-NN] The input index {} exceeds the inputs number {}."sv, Index, + GraphRef.OpenVINOModel->inputs().size()); + return WASINN::ErrNo::InvalidArgument; + } + + try { + ov::element::Type InputType = ov::element::f32; + ov::Shape InputShape(Tensor.Dimension.data(), + Tensor.Dimension.data() + Tensor.Dimension.size()); + ov::Tensor InputTensor = + ov::Tensor(InputType, InputShape, Tensor.Tensor.data()); + std::string Device; + if (!GetDeviceString(GraphRef.TargetDevice, Device)) { + spdlog::error("[WASI-NN] Failed to get device string for OpenVINO."sv); + return WASINN::ErrNo::InvalidArgument; + } + + ov::CompiledModel CompiledModel = + Env.OpenVINOCore.compile_model(GraphRef.OpenVINOModel, Device); + CxtRef.OpenVINOInferRequest = CompiledModel.create_infer_request(); + CxtRef.OpenVINOInferRequest.set_input_tensor(Index, InputTensor); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Set Input Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} + +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + // Check the output index. + if (GraphRef.OpenVINOModel->outputs().size() <= Index) { + spdlog::error( + "[WASI-NN] The output index {} exceeds the outputs number {}."sv, Index, + GraphRef.OpenVINOModel->outputs().size()); + return WASINN::ErrNo::InvalidArgument; + } + + try { + const ov::Tensor &OutputTensor = + CxtRef.OpenVINOInferRequest.get_output_tensor(Index); + BytesWritten = OutputTensor.get_byte_size(); + std::copy_n(static_cast(OutputTensor.data()), BytesWritten, + OutBuffer.data()); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Get Output Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} + +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + try { + CxtRef.OpenVINOInferRequest.infer(); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Infer Request Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] OpenVINO backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif +} // namespace WasmEdge::Host::WASINN::OpenVINO diff --git a/plugins/wasi_nn/wasinn_openvino.h b/plugins/wasi_nn/wasinn_openvino.h new file mode 100644 index 00000000..f7472ae9 --- /dev/null +++ b/plugins/wasi_nn/wasinn_openvino.h @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO +#include "openvino/openvino.hpp" +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::OpenVINO { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO +struct Graph { + ~Graph() noexcept {} + ov::Tensor OpenVINOIWeightTensor; + std::shared_ptr OpenVINOModel; + Device TargetDevice = Device::AUTO; +}; + +struct Context { + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} + ~Context() noexcept {} + uint32_t GraphId; + ov::InferRequest OpenVINOInferRequest; +}; + +struct Environ { + Environ() noexcept {} + ~Environ() noexcept {} + ov::Core OpenVINOCore; +}; +#else +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +struct Environ {}; +#endif + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::OpenVINO diff --git a/plugins/wasi_nn/wasinn_openvino_genai.cpp b/plugins/wasi_nn/wasinn_openvino_genai.cpp new file mode 100644 index 00000000..a7053ec9 --- /dev/null +++ b/plugins/wasi_nn/wasinn_openvino_genai.cpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_openvino_genai.h" +#include "wasinnenv.h" + +#include + +using namespace std::literals; + +namespace WasmEdge::Host::WASINN::OpenVINOGenAI { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINOGENAI + +Expect GetDeviceString(WASINN::Device TargetDevice, + std::string &DeviceString) noexcept { + switch (TargetDevice) { + case Device::CPU: + DeviceString = "CPU"; + break; + case Device::GPU: + DeviceString = "GPU"; + break; + default: + spdlog::error("[WASI-NN] Unsupported device type"sv); + return WASINN::ErrNo::InvalidArgument; + } + return WASINN::ErrNo::Success; +} + +Expect isStringTensor(const TensorData &Tensor) noexcept { + if (Tensor.RType != WASINN::TensorType::U8) { + spdlog::warn( + "[WASI-NN] Only STRING (u8) inputs and outputs are supported for " + "now. Input Type: {}"sv, + Tensor.RType); + // return WASINN::ErrNo::InvalidArgument; + } + if (Tensor.Dimension.size() != 1) { + spdlog::error("[WASI-NN] Tensor dimension is out of range, expect it under " + "1-dim, but got {}-dim."sv, + Tensor.Dimension.size()); + return WASINN::ErrNo::InvalidArgument; + } + return WASINN::ErrNo::Success; +} + +Expect +LLMPipelineBackend::SetContextInput(Context &CxtRef, uint32_t Index, + const TensorData &Tensor) { + + if (Index != 0) { + spdlog::error("[WASI-NN] The input index {} is out of range."sv, Index); + return WASINN::ErrNo::InvalidArgument; + } + + if (auto Res = isStringTensor(Tensor); !Res) { + return Res; + } + + try { + CxtRef.StringInput = + std::string(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Set Input Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} + +Expect LLMPipelineBackend::Generate(Context &CxtRef) { + try { + // TODO: let the user to set the generation config. + spdlog::warn("[WASI-NN] The generation config is not supported for now."sv); + spdlog::warn("[WASI-NN] Maximum token limit is set to 100."sv); + ov::genai::GenerationConfig config; + config.max_new_tokens = 100; + CxtRef.StringOutput = Model->generate(CxtRef.StringInput, config); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Generate Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} + +Expect +LLMPipelineBackend::GetContextOutput(Context &CxtRef, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) { + if (Index != 0) { + spdlog::error("[WASI-NN] The output index {} is out of range."sv, Index); + return WASINN::ErrNo::InvalidArgument; + } + + try { + BytesWritten = CxtRef.StringOutput.size(); + std::copy_n(reinterpret_cast(CxtRef.StringOutput.data()), + BytesWritten, OutBuffer.data()); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Get Output Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept { + // The graph builder length must be 3. + if (Builders.size() != 3) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 3"sv, + Builders.size()); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the XML and Weight raw buffer. + // Builder-0: Reserved, the string "LLMPipeline" + // Builder-1: Path to the dir model xml/bin files + // Builder-2: Empty for now (reserved for future use) + + // There are 4 types (text or img) x (text or img); we assume the input is 0 + // for now. + auto ModelType = std::string( + reinterpret_cast(Builders[0].data()), Builders[0].size()); + auto ModelPath = std::string( + reinterpret_cast(Builders[1].data()), Builders[1].size()); + // TODO: Support extra model information. (ex. enable kv cache) + [[maybe_unused]] auto ModelExtra = std::string( + reinterpret_cast(Builders[2].data()), Builders[2].size()); + + // Add a new graph. + uint32_t GId = Env.newGraph(Backend::OpenVINOGenAI); + auto &GraphRef = Env.NNGraph[GId].get(); + + // Store device information + GraphRef.TargetDevice = Device; + std::string DeviceString; + if (auto Err = GetDeviceString(Device, DeviceString); + Err != WASINN::ErrNo::Success) { + return Err; + } + + try { + // Create the OpenVINO GenAI Backend. + // Currently, we only support LLMPipeline. + if (ModelType == "LLMPipeline") { + GraphRef.OpenVINOGenAI = + std::make_shared(ModelPath, DeviceString); + } else { + spdlog::error("[WASI-NN] Unsupported model type: {}"sv, ModelType); + return WASINN::ErrNo::InvalidArgument; + } + + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Model Load Exception: {}"sv, EX.what()); + Env.deleteGraph(GId); + return WASINN::ErrNo::RuntimeError; + } + // Store the loaded graph. + GraphId = GId; + Env.NNGraph[GId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept { + // Check the network and the execution network with the graph ID. + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.OpenVINOGenAI == nullptr) { + spdlog::error("[WASI-NN] Model for Graph:{} is empty!"sv, GraphId); + return WASINN::ErrNo::MissingMemory; + } + // Create context. + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + if (GraphRef.OpenVINOGenAI == nullptr) { + spdlog::error("[WASI-NN] The founded openvino genei session is empty"sv); + return WASINN::ErrNo::MissingMemory; + } + + return GraphRef.OpenVINOGenAI->SetContextInput(CxtRef, Index, Tensor); +} + +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + if (GraphRef.OpenVINOGenAI == nullptr) { + spdlog::error("[WASI-NN] The founded openvino genei session is empty"sv); + return WASINN::ErrNo::MissingMemory; + } + + return GraphRef.OpenVINOGenAI->GetContextOutput(CxtRef, Index, OutBuffer, + BytesWritten); +} + +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + try { + GraphRef.OpenVINOGenAI->Generate(CxtRef); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Infer Request Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error( + "[WASI-NN] OpenVINO GenAI backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINOGenAI\" to build it."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif +} // namespace WasmEdge::Host::WASINN::OpenVINOGenAI diff --git a/plugins/wasi_nn/wasinn_openvino_genai.h b/plugins/wasi_nn/wasinn_openvino_genai.h new file mode 100644 index 00000000..cf6c673d --- /dev/null +++ b/plugins/wasi_nn/wasinn_openvino_genai.h @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINOGENAI +#include "openvino/openvino.hpp" +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::OpenVINOGenAI { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINOGENAI + +struct Context; +class OpenVINOGenAIBackend { +public: + OpenVINOGenAIBackend() = default; + virtual Expect SetContextInput(Context &CxtRef, uint32_t Index, + const TensorData &Tensor) = 0; + virtual Expect Generate(Context &CxtRef) = 0; + virtual Expect GetContextOutput(Context &CxtRef, + uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) = 0; + virtual ~OpenVINOGenAIBackend() noexcept {} +}; + +class LLMPipelineBackend : public OpenVINOGenAIBackend { +public: + LLMPipelineBackend(std::string Path, std::string Device) { + Model = std::make_shared(Path, Device); + } + ~LLMPipelineBackend() noexcept {} + Expect SetContextInput(Context &CxtRef, uint32_t Index, + const TensorData &Tensor) override; + Expect Generate(Context &CxtRef) override; + Expect GetContextOutput(Context &CxtRef, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) override; + +private: + std::shared_ptr Model; +}; + +struct Graph { + ~Graph() noexcept {} + std::shared_ptr OpenVINOGenAI; + Device TargetDevice = Device::AUTO; +}; + +struct Context { + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} + ~Context() noexcept {} + uint32_t GraphId; + std::string StringInput; + std::string StringOutput; + + // For image input/output + // ov::Tensor TensorInput; + // ov::Tensor TensorOutput; +}; + +struct Environ { + Environ() noexcept {} + ~Environ() noexcept {} + ov::Core OpenVINOCore; +}; +#else +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +struct Environ {}; +#endif + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::OpenVINOGenAI diff --git a/plugins/wasi_nn/wasinn_piper.cpp b/plugins/wasi_nn/wasinn_piper.cpp new file mode 100644 index 00000000..7e71efc9 --- /dev/null +++ b/plugins/wasi_nn/wasinn_piper.cpp @@ -0,0 +1,517 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_piper.h" +#include "common/errcode.h" +#include "common/span.h" +#include "wasinnenv.h" +#include "wasinntypes.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER +#include "simdjson.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::Piper { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER + +namespace { + +// Helper function to write the WAV header. +void writeWavHeader(int SampleRate, int16_t NumChannels, int32_t NumSamples, + std::vector &OutputBuffer) { + int32_t const ByteRate = SampleRate * NumChannels * sizeof(int16_t); + int32_t const DataSize = NumSamples * NumChannels * sizeof(int16_t); + int32_t const RiffSize = 36 + DataSize; + + auto PushU32 = [&](int32_t Val) { + OutputBuffer.push_back(Val & 0xFF); + OutputBuffer.push_back((Val >> 8) & 0xFF); + OutputBuffer.push_back((Val >> 16) & 0xFF); + OutputBuffer.push_back((Val >> 24) & 0xFF); + }; + auto PushU16 = [&](int16_t Val) { + OutputBuffer.push_back(Val & 0xFF); + OutputBuffer.push_back((Val >> 8) & 0xFF); + }; + auto PushStr = [&](const char *Str) { + for (int I = 0; I < 4; I++) + OutputBuffer.push_back(Str[I]); + }; + + PushStr("RIFF"); + PushU32(RiffSize); + PushStr("WAVE"); + PushStr("fmt "); + PushU32(16); + PushU16(1); + PushU16(NumChannels); + PushU32(SampleRate); + PushU32(ByteRate); + PushU16(NumChannels * sizeof(int16_t)); + PushU16(16); + PushStr("data"); + PushU32(DataSize); +} + +} // namespace +template +std::tuple getOption(simdjson::dom::object &Object, + std::string_view Key, T &Result) { + if (auto Error = Object[Key].get(Result)) { + if (Error == simdjson::error_code::NO_SUCH_FIELD) { + return {WASINN::ErrNo::Success, false}; + } + spdlog::error( + "[WASI-NN] Piper backend: Unable to retrieve the \"{}\" option: {}"sv, + Key, simdjson::error_message(Error)); + return {WASINN::ErrNo::InvalidArgument, false}; + } + return {WASINN::ErrNo::Success, true}; +} + +template +WASINN::ErrNo getOptionalOption(simdjson::dom::object &Object, + std::string_view Key, + std::optional &Result) { + auto Value = U{}; + auto [Err, HasValue] = getOption(Object, Key, Value); + if (HasValue) { + Result = Value; + } + return Err; +} + +WASINN::ErrNo parseSynthesisConfig(SynthesisConfig &SynthesisConfig, + simdjson::dom::object &Object) { + { + auto Value = std::optional{}; + if (auto Err = getOptionalOption(Object, "output_type", Value); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (Value) { + if (Value.value() == "wav") { + SynthesisConfig.OutputType = SynthesisConfigOutputType::OUTPUT_WAV; + } else if (Value.value() == "raw") { + SynthesisConfig.OutputType = SynthesisConfigOutputType::OUTPUT_RAW; + } else { + spdlog::error( + "[WASI-NN] Piper backend: The output_type option has an unknown value {}."sv, + Value.value()); + return WASINN::ErrNo::InvalidArgument; + } + } + } + { + auto SpeakerId = std::optional{}; + if (auto Err = getOptionalOption(Object, "speaker_id", SpeakerId); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (SpeakerId.has_value()) { + SynthesisConfig.SpeakerId = static_cast(SpeakerId.value()); + } + } + if (auto Err = getOptionalOption(Object, "noise_scale", + SynthesisConfig.NoiseScale); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (auto Err = getOptionalOption(Object, "length_scale", + SynthesisConfig.LengthScale); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (auto Err = getOptionalOption(Object, "noise_w", + SynthesisConfig.NoiseW); + Err != WASINN::ErrNo::Success) { + return Err; + } + return WASINN::ErrNo::Success; +} +WASINN::ErrNo parseRunConfig(RunConfig &RunConfig, + const std::string &String) noexcept { + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + if (auto Error = Parser.parse(String).get(Doc)) { + spdlog::error("[WASI-NN] Piper backend: Parse run config error: {}"sv, + simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidEncoding; + } + simdjson::dom::object Object; + if (auto Error = Doc.get(Object)) { + spdlog::error( + "[WASI-NN] Piper backend: The run config is not an object: {}"sv, + simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidArgument; + } + + auto ModelPath = std::optional{}; + if (auto Err = getOptionalOption(Object, "model", ModelPath); + Err != WASINN::ErrNo::Success) { + return Err; + } + // Verify model file exists + if (ModelPath) { + auto Path = std::filesystem::u8path(ModelPath.value()); + if (!std::filesystem::exists(Path)) { + spdlog::error("[WASI-NN] Piper backend: Model file doesn't exist"sv); + return WASINN::ErrNo::InvalidArgument; + } + RunConfig.ModelPath = Path; + } else { + spdlog::error( + "[WASI-NN] Piper backend: The model option is required but not provided"sv); + return WASINN::ErrNo::InvalidArgument; + } + + auto ModelConfigPath = std::optional{}; + if (auto Err = getOptionalOption(Object, "config", ModelConfigPath); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (ModelConfigPath) { + RunConfig.ModelConfigPath = + std::filesystem::u8path(ModelConfigPath.value()); + } else { + RunConfig.ModelConfigPath = RunConfig.ModelPath; + RunConfig.ModelConfigPath += ".json"; + } + // Verify model config exists + if (!std::filesystem::exists(RunConfig.ModelConfigPath)) { + spdlog::error("[WASI-NN] Piper backend: Model config doesn't exist"sv); + return WASINN::ErrNo::InvalidArgument; + } + + if (auto Err = parseSynthesisConfig(RunConfig.DefaultSynthesisConfig, Object); + Err != WASINN::ErrNo::Success) { + return Err; + } + { + auto Path = std::optional{}; + if (auto Err = getOptionalOption(Object, "espeak_data", Path); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (Path) { + RunConfig.ESpeakDataPath = std::filesystem::u8path(Path.value()); + } + } + if (auto Err = + std::get<0>(getOption(Object, "json_input", RunConfig.JsonInput)); + Err != WASINN::ErrNo::Success) { + return Err; + } + return WASINN::ErrNo::Success; +} + +void updatePiperOptions(const SynthesisConfig &SynthesisConfig, + piper_synthesize_options &Options) { + if (SynthesisConfig.SpeakerId) { + Options.speaker_id = SynthesisConfig.SpeakerId.value(); + } + if (SynthesisConfig.NoiseScale) { + Options.noise_scale = SynthesisConfig.NoiseScale.value(); + } + if (SynthesisConfig.LengthScale) { + Options.length_scale = SynthesisConfig.LengthScale.value(); + } + if (SynthesisConfig.NoiseW) { + Options.noise_w_scale = SynthesisConfig.NoiseW.value(); + } +} + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, WASINN::Device, + uint32_t &GraphId) noexcept { + // The graph builder length must be 1. + if (Builders.size() != 1) { + spdlog::error( + "[WASI-NN] Piper backend: Wrong GraphBuilder Length {:d}, expect 1"sv, + Builders.size()); + return WASINN::ErrNo::InvalidArgument; + } + + // Add a new graph. + uint32_t const GId = Env.newGraph(Backend::Piper); + auto &GraphRef = Env.NNGraph[GId].get(); + GraphRef.Config = std::make_unique(); + auto String = std::string{Builders[0].begin(), Builders[0].end()}; + if (auto Res = parseRunConfig(*GraphRef.Config, String); + Res != WASINN::ErrNo::Success) { + Env.deleteGraph(GId); + spdlog::error("[WASI-NN] Piper backend: Failed to parse run config."sv); + return Res; + } + + std::string EspeakPath = ""; + if (GraphRef.Config->ESpeakDataPath) { + EspeakPath = GraphRef.Config->ESpeakDataPath->string(); + } + + piper_synthesizer *Synth = + piper_create(GraphRef.Config->ModelPath.string().c_str(), + GraphRef.Config->ModelConfigPath.string().c_str(), + EspeakPath.empty() ? nullptr : EspeakPath.c_str()); + + if (!Synth) { + spdlog::error( + "[WASI-NN] Piper backend: Failed to create piper synthesizer."sv); + Env.deleteGraph(GId); + return WASINN::ErrNo::InvalidArgument; + } + + GraphRef.Synth = std::unique_ptr(Synth); + GraphId = GId; + Env.NNGraph[GId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept { + // Create context. + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept { + if (Index != 0) { + spdlog::error("[WASI-NN] Piper backend: Input index must be 0."sv); + return WASINN::ErrNo::InvalidArgument; + } + if (Tensor.Dimension.size() != 1) { + spdlog::error( + "[WASI-NN] Piper backend: Input tensor dimension must be 1D."sv); + return WASINN::ErrNo::InvalidArgument; + } + + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + CxtRef.Line = + std::string(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + + if (GraphRef.Config->JsonInput) { + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + simdjson::padded_string const PaddedInput(CxtRef.Line.value()); + + if (Parser.parse(PaddedInput).get(Doc) != simdjson::SUCCESS) { + spdlog::error("[WASI-NN] Piper backend: Failed to parse JSON input."sv); + return WASINN::ErrNo::InvalidArgument; + } + + simdjson::dom::object JsonObj; + if (Doc.get(JsonObj) != simdjson::SUCCESS) { + spdlog::error("[WASI-NN] Piper backend: JSON input is not an object."sv); + return WASINN::ErrNo::InvalidArgument; + } + + SynthesisConfig NewConfig; + if (auto Err = parseSynthesisConfig(NewConfig, JsonObj); + Err != WASINN::ErrNo::Success) { + return Err; + } + CxtRef.JsonInputSynthesisConfig = + std::make_unique>(NewConfig); + } + return WASINN::ErrNo::Success; +} + +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept { + if (Index != 0) { + spdlog::error("[WASI-NN] Piper backend: Output index must be 0."sv); + return WASINN::ErrNo::InvalidArgument; + } + + auto &CxtRef = Env.NNContext[ContextId].get(); + + if (!CxtRef.Output) { + spdlog::error("[WASI-NN] Piper backend: No output available."sv); + return WASINN::ErrNo::InvalidArgument; + } + + if (CxtRef.Output->size() >= std::numeric_limits::max()) { + spdlog::error( + "[WASI-NN] Piper backend: Output size {} is greater than std::numeric_limits::max() {}."sv, + CxtRef.Output->size(), std::numeric_limits::max()); + return WASINN::ErrNo::InvalidArgument; + } + + if (CxtRef.Output->size() > OutBuffer.size_bytes()) { + spdlog::error( + "[WASI-NN] Piper backend: Output size {} is greater than buffer size {}."sv, + CxtRef.Output->size(), OutBuffer.size_bytes()); + return WASINN::ErrNo::InvalidArgument; + } + + std::memcpy(OutBuffer.data(), CxtRef.Output->data(), CxtRef.Output->size()); + BytesWritten = CxtRef.Output->size(); + return WASINN::ErrNo::Success; +} + +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + if (!CxtRef.Line) { + spdlog::error("[WASI-NN] Piper backend: Input is not set."sv); + return WASINN::ErrNo::InvalidArgument; + } + + std::string TextToSpeak = CxtRef.Line.value(); + + if (GraphRef.Config->JsonInput) { + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + simdjson::padded_string const PaddedInput(TextToSpeak); + + if (Parser.parse(PaddedInput).get(Doc) != simdjson::SUCCESS) { + spdlog::error("[WASI-NN] Piper backend: Failed to parse JSON input."sv); + return WASINN::ErrNo::InvalidArgument; + } + + simdjson::dom::object JsonObj; + if (Doc.get(JsonObj) != simdjson::SUCCESS) { + spdlog::error("[WASI-NN] Piper backend: JSON input is not an object."sv); + return WASINN::ErrNo::InvalidArgument; + } + + std::string_view TmpText; + if (JsonObj["text"].get(TmpText) != simdjson::SUCCESS) { + spdlog::error( + "[WASI-NN] Piper backend: JSON input must contain 'text' field."sv); + return WASINN::ErrNo::InvalidArgument; + } + TextToSpeak = std::string(TmpText); + + SynthesisConfig NewConfig; + if (parseSynthesisConfig(NewConfig, JsonObj) == WASINN::ErrNo::Success) { + CxtRef.JsonInputSynthesisConfig = + std::make_unique>(NewConfig); + } + } + + piper_synthesize_options Options = + piper_default_synthesize_options(GraphRef.Synth.get()); + updatePiperOptions(GraphRef.Config->DefaultSynthesisConfig, Options); + + auto OutputType = SynthesisConfigOutputType::OUTPUT_WAV; + if (GraphRef.Config->DefaultSynthesisConfig.OutputType) { + OutputType = GraphRef.Config->DefaultSynthesisConfig.OutputType.value(); + } + + if (CxtRef.JsonInputSynthesisConfig && + CxtRef.JsonInputSynthesisConfig->has_value()) { + updatePiperOptions(CxtRef.JsonInputSynthesisConfig->value(), Options); + if (CxtRef.JsonInputSynthesisConfig->value().OutputType) { + OutputType = CxtRef.JsonInputSynthesisConfig->value().OutputType.value(); + } + } + + int const Res = piper_synthesize_start(GraphRef.Synth.get(), + TextToSpeak.c_str(), &Options); + if (Res != PIPER_OK) { + spdlog::error("[WASI-NN] Piper backend: piper_synthesize_start failed."sv); + return WASINN::ErrNo::RuntimeError; + } + + std::vector AudioBuffer; + piper_audio_chunk Chunk; + int SampleRate = 0; + constexpr float MaxWavValue = 32767.0f; + + while (piper_synthesize_next(GraphRef.Synth.get(), &Chunk) != PIPER_DONE) { + if (Chunk.num_samples == 0) { + continue; + } + SampleRate = Chunk.sample_rate; + size_t OriginalSize = AudioBuffer.size(); + AudioBuffer.resize(OriginalSize + Chunk.num_samples); + + for (size_t I = 0; I < Chunk.num_samples; I++) { + float Sample = Chunk.samples[I]; + Sample = std::clamp(Sample, -1.0f, 1.0f); + AudioBuffer[OriginalSize + I] = + static_cast(Sample * MaxWavValue); + } + } + + CxtRef.Output.emplace(); + constexpr int DefaultPiperSampleRate = 22050; + if (OutputType == SynthesisConfigOutputType::OUTPUT_WAV) { + if (SampleRate == 0) { + SampleRate = DefaultPiperSampleRate; + } + size_t const TotalSize = 44 + (AudioBuffer.size() * sizeof(int16_t)); + CxtRef.Output->reserve(TotalSize); + + writeWavHeader(SampleRate, 1, static_cast(AudioBuffer.size()), + *CxtRef.Output); + + const uint8_t *RawData = + reinterpret_cast(AudioBuffer.data()); + CxtRef.Output->insert(CxtRef.Output->end(), RawData, + RawData + (AudioBuffer.size() * sizeof(int16_t))); + + } else { + size_t const TotalSize = AudioBuffer.size() * sizeof(int16_t); + CxtRef.Output->resize(TotalSize); + const uint8_t *RawData = + reinterpret_cast(AudioBuffer.data()); + std::copy_n(RawData, TotalSize, CxtRef.Output->data()); + } + + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] Piper backend is not supported."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif +} // namespace WasmEdge::Host::WASINN::Piper diff --git a/plugins/wasi_nn/wasinn_piper.h b/plugins/wasi_nn/wasinn_piper.h new file mode 100644 index 00000000..e10e3150 --- /dev/null +++ b/plugins/wasi_nn/wasinn_piper.h @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "wasinntypes.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER +#include + +#include +#include +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} // namespace WasmEdge::Host::WASINN + +namespace WasmEdge::Host::WASINN::Piper { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER +enum class SynthesisConfigOutputType { OUTPUT_WAV, OUTPUT_RAW }; +struct SynthesisConfig { + // Type of output to produce. + // Default is a WAV file. + std::optional OutputType; + + // Numerical id of the default speaker (multi-speaker voices) + std::optional SpeakerId; + + // Amount of noise to add during audio generation + std::optional NoiseScale; + + // Speech speed (1 = normal, < 1 is faster, > 1 is slower) + std::optional LengthScale; + + // Variation in phoneme lengths + std::optional NoiseW; + + // NOTE: Phoneme/Sentence silence configuration is not exposed + // in the new upstream C API (piper.h) and has been removed. +}; + +struct RunConfig { + // Path to .onnx voice file + std::filesystem::path ModelPath; + + // Path to JSON voice config file + std::filesystem::path ModelConfigPath; + + // Path to espeak-ng data directory + std::optional ESpeakDataPath; + + // input is JSON with format: + // { + // "text": str, (required) + // "speaker_id": int, (optional) + // "output_type": str, (optional, "wav" or "raw") + // } + // including options in SynthesisConfig + bool JsonInput = false; + + SynthesisConfig DefaultSynthesisConfig; +}; + +// Custom deleter for the piper_synthesizer +struct PiperDeleter { + void operator()(piper_synthesizer *P) const { + if (P) + piper_free(P); + } +}; + +struct Graph { + std::unique_ptr Config; + std::unique_ptr Synth; +}; +struct Context { + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} + uint32_t GraphId; + std::optional Line; + std::unique_ptr> JsonInputSynthesisConfig; + std::optional> Output; +}; +#else +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::Piper diff --git a/plugins/wasi_nn/wasinn_tf.cpp b/plugins/wasi_nn/wasinn_tf.cpp new file mode 100644 index 00000000..caf4492b --- /dev/null +++ b/plugins/wasi_nn/wasinn_tf.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_tf.h" +#include "wasinnenv.h" + +using namespace std::literals; + +namespace WasmEdge::Host::WASINN::Tensorflow { +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] Tensorflow backend is not supported."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +} // namespace WasmEdge::Host::WASINN::Tensorflow diff --git a/plugins/wasi_nn/wasinn_tf.h b/plugins/wasi_nn/wasinn_tf.h new file mode 100644 index 00000000..c87329cd --- /dev/null +++ b/plugins/wasi_nn/wasinn_tf.h @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::Tensorflow { +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::Tensorflow diff --git a/plugins/wasi_nn/wasinn_tfl.cpp b/plugins/wasi_nn/wasinn_tfl.cpp new file mode 100644 index 00000000..51aab788 --- /dev/null +++ b/plugins/wasi_nn/wasinn_tfl.cpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_tfl.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +#include "tensorflow/lite/c/common.h" +#endif + +using namespace std::literals; + +namespace WasmEdge::Host::WASINN::TensorflowLite { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept { + if ((Device != WASINN::Device::CPU)) { + spdlog::error("[WASI-NN] TensorflowLite Only support CPU target."sv); + return WASINN::ErrNo::InvalidArgument; + } + // The graph builder length must be 1. + if (Builders.size() != 1) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1"sv, + Builders.size()); + return WASINN::ErrNo::InvalidArgument; + } + // Add a new graph. + uint32_t GId = Env.newGraph(Backend::TensorflowLite); + auto &GraphRef = Env.NNGraph[GId].get(); + + // Copy graph builder data to TfLiteModData and create a new TfLiteModel. + GraphRef.TfLiteModData.assign(Builders[0].begin(), Builders[0].end()); + GraphRef.TFLiteMod = TfLiteModelCreate(GraphRef.TfLiteModData.data(), + GraphRef.TfLiteModData.size()); + if (unlikely(GraphRef.TFLiteMod == nullptr)) { + spdlog::error("[WASI-NN] Cannot import TFLite model"sv); + Env.deleteGraph(GId); + return WASINN::ErrNo::InvalidArgument; + } + + // Store the loaded graph. + GraphId = GId; + Env.NNGraph[GId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept { + // Check the network and the execution network with the graph ID. + if (Env.NNGraph[GraphId].get().TFLiteMod == nullptr) { + spdlog::error("[WASI-NN] Model for Graph:{} is missing!"sv, GraphId); + return WASINN::ErrNo::MissingMemory; + } + + // Create context. + uint32_t CId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + auto &CxtRef = Env.NNContext[CId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + auto *TFLiteOps = TfLiteInterpreterOptionsCreate(); + TfLiteInterpreterOptionsSetNumThreads(TFLiteOps, 2); + CxtRef.TFLiteInterp = TfLiteInterpreterCreate(GraphRef.TFLiteMod, TFLiteOps); + TfLiteInterpreterOptionsDelete(TFLiteOps); + if (unlikely(CxtRef.TFLiteInterp == nullptr)) { + spdlog::error("[WASI-NN] Cannot create TFLite interpreter."sv); + Env.deleteContext(CId); + return WASINN::ErrNo::Busy; + } + TfLiteInterpreterAllocateTensors(CxtRef.TFLiteInterp); + + ContextId = CId; + Env.NNContext[ContextId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const WASINN::TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + uint32_t InCnt = TfLiteInterpreterGetInputTensorCount(CxtRef.TFLiteInterp); + if (Index >= InCnt) { + spdlog::error("[WASI-NN] Invalid index id {} for the input, only {} " + "inputs are allowed", + Index, InCnt); + return WASINN::ErrNo::InvalidArgument; + } + + auto *HoldTensor = + TfLiteInterpreterGetInputTensor(CxtRef.TFLiteInterp, Index); + // Check the input data size. + const auto HoldTensorByteSize = TfLiteTensorByteSize(HoldTensor); + if (HoldTensorByteSize != Tensor.Tensor.size()) { + spdlog::error("[WASI-NN] Expect tensor byte size {}, but got {}"sv, + HoldTensorByteSize, Tensor.Tensor.size()); + return WASINN::ErrNo::InvalidArgument; + } + // Check the input tensor dimensions. + const auto HoldTensorNumDims = TfLiteTensorNumDims(HoldTensor); + if (static_cast(HoldTensorNumDims) != Tensor.Dimension.size()) { + spdlog::error( + "[WASI-NN] Expect tensor number of dimensions {}, but got {}"sv, + HoldTensorNumDims, Tensor.Dimension.size()); + return WASINN::ErrNo::InvalidArgument; + } + for (uint32_t I = 0; I < Tensor.Dimension.size(); I++) { + const auto HoldTensorDim = TfLiteTensorDim(HoldTensor, I); + if (static_cast(HoldTensorDim) != Tensor.Dimension[I]) { + spdlog::error("[WASI-NN] Expect tensor dimension[{}] = {}, but got {}"sv, + I, HoldTensorDim, Tensor.Dimension[I]); + return WASINN::ErrNo::InvalidArgument; + } + } + // Check the input tensor type. + WASINN::TensorType LiteType; + switch (const auto Type = TfLiteTensorType(HoldTensor)) { + case TfLiteType::kTfLiteUInt8: + LiteType = WASINN::TensorType::U8; + break; + case TfLiteType::kTfLiteFloat16: + LiteType = WASINN::TensorType::F16; + break; + case TfLiteType::kTfLiteFloat32: + LiteType = WASINN::TensorType::F32; + break; + case TfLiteType::kTfLiteInt32: + LiteType = WASINN::TensorType::I32; + break; + default: + spdlog::error("[WASI-NN] Unsupported TFLite type: {}"sv, + TfLiteTypeGetName(Type)); + return WASINN::ErrNo::InvalidArgument; + } + + if (unlikely(LiteType != Tensor.RType)) { + spdlog::error("[WASI-NN] Expect tensor type {}, but got {}"sv, LiteType, + Tensor.RType); + return WASINN::ErrNo::InvalidArgument; + } + TfLiteStatus Stat = TfLiteTensorCopyFromBuffer( + HoldTensor, Tensor.Tensor.data(), Tensor.Tensor.size()); + if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { + spdlog::error("[WASI-NN] Copy tensor memory failed"sv); + return WASINN::ErrNo::Busy; + } + + return WASINN::ErrNo::Success; +} + +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + uint32_t OutCnt = TfLiteInterpreterGetOutputTensorCount(CxtRef.TFLiteInterp); + if (Index >= OutCnt) { + spdlog::error("[WASI-NN] Invalid index id {} for the input, only {} " + "outputs are allowed"sv, + Index, OutCnt); + return WASINN::ErrNo::InvalidArgument; + } + const TfLiteTensor *HoldTensor = + TfLiteInterpreterGetOutputTensor(CxtRef.TFLiteInterp, Index); + const uint32_t BytesToWrite = TfLiteTensorByteSize(HoldTensor); + // Check out buffer max size. + if (OutBuffer.size() < BytesToWrite) { + spdlog::error("[WASI-NN] Expect out buffer max size {}, but got {}"sv, + BytesToWrite, OutBuffer.size()); + return WASINN::ErrNo::InvalidArgument; + } + TfLiteTensorCopyToBuffer(HoldTensor, OutBuffer.data(), BytesToWrite); + BytesWritten = BytesToWrite; + return WASINN::ErrNo::Success; +} + +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + // Run session + if (unlikely(CxtRef.TFLiteInterp == nullptr)) { + spdlog::error("[WASI-NN] Tensorflow Lite context empty"sv); + return WASINN::ErrNo::MissingMemory; + } + TfLiteStatus Stat = TfLiteInterpreterInvoke(CxtRef.TFLiteInterp); + if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { + spdlog::error("[WASI-NN] Invocation failed."sv); + return WASINN::ErrNo::Busy; + } + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error( + "[WASI-NN] TensorflowLite backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} + +#endif +} // namespace WasmEdge::Host::WASINN::TensorflowLite diff --git a/plugins/wasi_nn/wasinn_tfl.h b/plugins/wasi_nn/wasinn_tfl.h new file mode 100644 index 00000000..e2a02419 --- /dev/null +++ b/plugins/wasi_nn/wasinn_tfl.h @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +#include "tensorflow/lite/c/c_api.h" +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::TensorflowLite { + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +struct Graph { + ~Graph() noexcept { + if (TFLiteMod) { + TfLiteModelDelete(TFLiteMod); + } + } + std::vector TfLiteModData; + TfLiteModel *TFLiteMod = nullptr; +}; + +struct Context { +public: + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} + ~Context() noexcept { + if (TFLiteInterp) { + TfLiteInterpreterDelete(TFLiteInterp); + } + } + uint32_t GraphId; + TfLiteInterpreter *TFLiteInterp = nullptr; +}; +#else +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::TensorflowLite diff --git a/plugins/wasi_nn/wasinn_torch.cpp b/plugins/wasi_nn/wasinn_torch.cpp new file mode 100644 index 00000000..95cfc03c --- /dev/null +++ b/plugins/wasi_nn/wasinn_torch.cpp @@ -0,0 +1,335 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_torch.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +#include +#endif + +namespace WasmEdge::Host::WASINN::PyTorch { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + +Expect TorchScript::setDevice(Device Device) { + if (Device == Device::CPU) { + TorchDevice = at::kCPU; + return ErrNo::Success; + } else if (Device == Device::GPU) { + if (!torch::cuda::is_available()) { + spdlog::error("[WASI-NN] Torch: CUDA Unavailable. Please check if the " + "installed Torch version or platform supports CUDA."sv); + return ErrNo::InvalidArgument; + } + TorchDevice = at::kCUDA; + return ErrNo::Success; + } + + spdlog::error("[WASI-NN] Torch: Unknown target device. We currently support " + "only CPU and GPU targets."sv); + return ErrNo::InvalidArgument; +} + +Expect TorchScript::loadFromBinary(std::istream &In, Device Device) { + if (auto Err = setDevice(Device); Err != ErrNo::Success) { + return Err; + } + TorchModel = torch::jit::load(In); + return ErrNo::Success; +} + +Expect TorchScript::loadFromPath(const std::string &Path, + Device Device) { + if (auto Err = setDevice(Device); Err != ErrNo::Success) { + return Err; + } + TorchModel = torch::jit::load(Path); + return ErrNo::Success; +} + +Expect TorchScript::run(std::vector In, + std::vector &Out) { + std::vector Inputs; + std::vector Outputs; + for (auto &OneOf : In) { + Inputs.push_back(OneOf); + } + auto RawOutput = TorchModel.forward(Inputs); + if (RawOutput.isTensorList()) { + auto OutTensors = RawOutput.toTensorVector(); + for (auto &OneOf : OutTensors) { + Out.push_back(OneOf.clone()); + } + } else if (RawOutput.isTuple()) { + auto OutTensorsTuple = RawOutput.toTuple()->elements(); + for (auto &OneOf : OutTensorsTuple) { + Out.push_back(OneOf.toTensor().clone()); + } + } else if (RawOutput.isTensor()) { + auto OutTensor = RawOutput.toTensor(); + Out.push_back(OutTensor.clone()); + } else { + spdlog::error( + "[WASI-NN] Torch: The output can only be one of the following tensor " + "types: a tensor, a list of tensors, or a tuple of tensors."sv); + return ErrNo::InvalidArgument; + } + return ErrNo::Success; +} + +AOTInductor::AOTInductor() : TorchModel(nullptr) { +#if defined(_GLIBCXX_USE_CXX11_ABI) && _GLIBCXX_USE_CXX11_ABI == 1 + spdlog::warn( + "[WASI-NN] Torch: AOTInductor build by pip default is not supported in " + "_GLIBCXX_USE_CXX11_ABI=1. Please rebuild the WasmEdge with " + "_GLIBCXX_USE_CXX11_ABI=0."sv); +#endif +} + +Expect AOTInductor::setDevice(Device Device) { + if (Device == Device::CPU) { + TorchDevice = at::kCPU; + return ErrNo::Success; + } else if (Device == Device::GPU) { +#ifdef TORCHAOTI_USE_CUDA + TorchDevice = at::kCUDA; + return ErrNo::Success; +#else + spdlog::error("[WASI-NN] Torch: Please rebuild the plugin with AOTInductor " + "CUDA support."sv); + return ErrNo::InvalidArgument; +#endif + } + + spdlog::error("[WASI-NN] Torch: Unknown target device. We currently support " + "only CPU and GPU targets."sv); + return ErrNo::InvalidArgument; +} + +Expect AOTInductor::loadFromBinary(std::istream &, Device) { + spdlog::error( + "[WASI-NN] Torch: AOTInductor can not load by binary data. Please " + "pass the share library name (*.so) in nn-preload"sv); + return ErrNo::InvalidArgument; +} + +Expect AOTInductor::loadFromPath(const std::string &Path, + Device Device) { + if (auto Err = setDevice(Device); Err != ErrNo::Success) { + return Err; + } + if (TorchDevice == at::kCPU) { + TorchModel = new torch::inductor::AOTIModelContainerRunnerCpu(Path.c_str()); + } else if (TorchDevice == at::kCUDA) { +#ifdef TORCHAOTI_USE_CUDA + TorchModel = + new torch::inductor::AOTIModelContainerRunnerCuda(Path.c_str()); +#else + spdlog::error("[WASI-NN] Torch: Please rebuild the plugin with AOTInductor " + "CUDA support."sv); + return ErrNo::InvalidArgument; +#endif + } else { + spdlog::error("[WASI-NN] Torch: Can not load the AOTInductor."sv); + return ErrNo::InvalidArgument; + } + return ErrNo::Success; +} + +Expect AOTInductor::run(std::vector In, + std::vector &Out) { + std::vector RawOutput = TorchModel->run(In); + + for (auto &OneOf : RawOutput) { + Out.push_back(OneOf.clone()); + } + return ErrNo::Success; +} + +PyModelBackend guessPyModelBackendType(const std::string_view &Model) { + // TODO: Add more model type detection when we supporet more OS. + // ex .dll, .dylib, etc. + if (Model.substr(0, 8) == "preload:"sv) { + if (Model.substr(Model.size() - 3, 3) == ".so"sv) { + // AOTInductor only accept the shared library. + return PyModelBackend::AOTInductor; + } + } + + // ELF Header: 0x7f 'E' 'L' 'F' + if (Model.substr(0, 4) == "\x7f\x45\x4c\x46"sv) { + return PyModelBackend::AOTInductor; + } + + // Fall back to TorchScript if the model type is not set. + // This keep the compatibility with the old version. + return PyModelBackend::TorchScript; +} + +Expect load(WasiNNEnvironment &Env, Span> Builders, + Device Device, uint32_t &GraphId) noexcept { + // The graph builder length must be 1. + if (Builders.size() != 1) { + spdlog::error("[WASI-NN] Torch: Wrong GraphBuilder Length {:d}, expect 1"sv, + Builders.size()); + return ErrNo::InvalidArgument; + } + + auto Weight = Builders[0]; + // Add a new graph. + uint32_t GId = Env.newGraph(Backend::PyTorch); + auto &GraphRef = Env.NNGraph[GId].get(); + + // Load the model from the binary data. + // Note: Pytorch use try catch to handle the error. + try { + const std::string_view BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + PyModelBackend ModelType = guessPyModelBackendType(BinModel); + + if (ModelType == PyModelBackend::TorchScript) { + GraphRef.Model = new TorchScript(); + } else if (ModelType == PyModelBackend::AOTInductor) { + GraphRef.Model = new AOTInductor(); + } else { + spdlog::error("[WASI-NN] Torch: Unknown model type."sv); + return ErrNo::InvalidArgument; + } + + if (BinModel.substr(0, 8) == "preload:"sv) { + const std::string ModelFilePath(BinModel.substr(8)); + GraphRef.Model->loadFromPath(ModelFilePath, Device); + } else { + std::istringstream BinRead{std::string(BinModel)}; + // std::istringstream BinRead(BinModel); // Need C++26... + GraphRef.Model->loadFromBinary(BinRead, Device); + } + } catch (const c10::Error &e) { + spdlog::error("[WASI-NN] Torch: Failed when load the TorchScript model."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + + GraphId = GId; + Env.NNGraph[GId].setReady(); + return ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + if (Index >= CxtRef.TorchInputs.size()) { + CxtRef.TorchInputs.resize(Index + 1); + } + if (Tensor.RType != TensorType::F32) { + spdlog::error( + "[WASI-NN] Torch: Only F32 inputs and outputs are supported for now."sv); + return ErrNo::InvalidArgument; + } + auto Options = + torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); + std::vector Dims; + for (size_t I = 0; I < Tensor.Dimension.size(); I++) { + Dims.push_back(static_cast(Tensor.Dimension[I])); + } + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + torch::Tensor InTensor = + torch::from_blob(reinterpret_cast(Tensor.Tensor.data()), Dims, + Options) + .to(GraphRef.Model->getDevice()); + + CxtRef.TorchInputs[Index] = InTensor.clone(); + return ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + if (CxtRef.TorchOutputs.size() <= Index) { + spdlog::error( + "[WASI-NN] Torch: The output index {} exceeds the outputs number {}."sv, + Index, CxtRef.TorchOutputs.size()); + return ErrNo::InvalidArgument; + } + torch::Tensor OutTensor = + CxtRef.TorchOutputs[Index].to(at::kCPU).toType(torch::kFloat32); + float *TensorBuffer = OutTensor.data_ptr(); + + size_t BlobSize = 1; + for (auto I : OutTensor.sizes()) { + BlobSize *= I; + } + uint32_t BytesToWrite = + std::min(static_cast(BlobSize * 4), OutBuffer.size()); + std::copy_n(reinterpret_cast(TensorBuffer), BytesToWrite, + OutBuffer.data()); + BytesWritten = BytesToWrite; + return ErrNo::Success; +} + +Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + if (CxtRef.TorchInputs.size() == 0) { + spdlog::error("[WASI-NN] Torch: Input is not set!"sv); + return ErrNo::InvalidArgument; + } + for (size_t I = 0; I < CxtRef.TorchInputs.size(); I++) { + torch::jit::IValue InTensor = CxtRef.TorchInputs[I]; + if (InTensor.isNone()) { + spdlog::error("[WASI-NN] Torch: Input [{}] is not set!"sv, I); + return ErrNo::InvalidArgument; + } + } + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + return GraphRef.Model->run(CxtRef.TorchInputs, CxtRef.TorchOutputs); +} + +Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.Model) { + delete GraphRef.Model; + } + return ErrNo::Success; +} + +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] PyTorch backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); + return ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WasiNNEnvironment &, Span>, Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect unload(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} + +#endif +} // namespace WasmEdge::Host::WASINN::PyTorch diff --git a/plugins/wasi_nn/wasinn_torch.h b/plugins/wasi_nn/wasinn_torch.h new file mode 100644 index 00000000..8a9d8cde --- /dev/null +++ b/plugins/wasi_nn/wasinn_torch.h @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +#include +#include +#ifdef TORCHAOTI_USE_CUDA +#include +#endif +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::PyTorch { + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + +class PyBaseModule { +public: + virtual ~PyBaseModule() = default; + virtual Expect setDevice(Device Device) = 0; + virtual Expect loadFromPath(const std::string &Path, + Device Device) = 0; + virtual Expect loadFromBinary(std::istream &In, Device Device) = 0; + virtual Expect run(std::vector In, + std::vector &Out) = 0; + + torch::DeviceType getDevice() const { return TorchDevice; } + +protected: + torch::DeviceType TorchDevice = at::kCPU; +}; + +class TorchScript : public PyBaseModule { + Expect setDevice(Device Device) override; + +public: + Expect loadFromPath(const std::string &Path, Device Device) override; + Expect loadFromBinary(std::istream &In, Device Device) override; + Expect run(std::vector In, + std::vector &Out) override; + + torch::jit::Module TorchModel; +}; + +class AOTInductor : public PyBaseModule { + Expect setDevice(Device Device) override; + +public: + AOTInductor(); + Expect loadFromPath(const std::string &Path, Device Device) override; + Expect loadFromBinary(std::istream &In, Device Device) override; + Expect run(std::vector In, + std::vector &Out) override; + + torch::inductor::AOTIModelContainerRunner *TorchModel; + + ~AOTInductor() { + if (TorchModel) { + delete TorchModel; + } + } +}; + +enum class PyModelBackend { TorchScript, AOTInductor, UNKNOWN }; + +struct Graph { + PyBaseModule *Model = nullptr; +}; + +struct Context { +public: + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} + uint32_t GraphId; + std::vector TorchInputs; + std::vector TorchOutputs; +}; +#else +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; +} // namespace WasmEdge::Host::WASINN::PyTorch diff --git a/plugins/wasi_nn/wasinn_whisper.cpp b/plugins/wasi_nn/wasinn_whisper.cpp new file mode 100644 index 00000000..f82bd29d --- /dev/null +++ b/plugins/wasi_nn/wasinn_whisper.cpp @@ -0,0 +1,1160 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_whisper.h" +#include "host/wasi/vfs_io.h" +#include "wasinnenv.h" +#include +#include + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER +#define DR_WAV_IMPLEMENTATION +#include "simdjson.h" +#include + +#include +#endif + +using namespace std::literals; + +namespace WasmEdge::Host::WASINN::Whisper { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER + +namespace { +int timestampToSample(int64_t T, int NSamples, int WhisperSampleRate) { + return std::max(0, std::min(static_cast(NSamples) - 1, + static_cast((T * WhisperSampleRate) / 100))); +} +std::string toTimestamp(int64_t T, bool Comma) { + int64_t Msec = T * 10; + int64_t Hr = Msec / (1000 * 60 * 60); + Msec = Msec - Hr * (1000 * 60 * 60); + int64_t Min = Msec / (1000 * 60); + Msec = Msec - Min * (1000 * 60); + int64_t Sec = Msec / 1000; + Msec = Msec - Sec * 1000; + + char Buf[32] = {}; + snprintf(Buf, sizeof(Buf), "%02d:%02d:%02d%s%03d", static_cast(Hr), + static_cast(Min), static_cast(Sec), Comma ? "," : ".", + static_cast(Msec)); + + return std::string(Buf); +} + +std::string +estimateDiarizationSpeaker(const std::vector> PCMF32s, + int64_t T0, int64_t T1, bool IdOnly = false) { + std::string Speaker = ""; + const int64_t NSamples = PCMF32s[0].size(); + + const int64_t Is0 = timestampToSample(T0, NSamples, WHISPER_SAMPLE_RATE); + const int64_t Is1 = timestampToSample(T1, NSamples, WHISPER_SAMPLE_RATE); + + double Energy0 = 0.0f; + double Energy1 = 0.0f; + + for (int64_t I = Is0; I < Is1; I++) { + Energy0 += fabs(PCMF32s[0][I]); + Energy1 += fabs(PCMF32s[1][I]); + } + + if (Energy0 > 1.1 * Energy1) { + Speaker = "0"; + } else if (Energy1 > 1.1 * Energy0) { + Speaker = "1"; + } else { + Speaker = "?"; + } + + if (!IdOnly) { + Speaker.insert(0, "(speaker "); + Speaker.append(")"); + } + + return Speaker; +} + +bool outputSrt(WasiNNEnvironment &Env, whisper_context *Ctx, + const std::string &Fname, const Config &Params, + const std::vector> &PCMF32s) { + WasmEdge::FStream::OFStream Fout(Fname, Env.getEnv()); + if (!Fout.is_open()) { + spdlog::error("[WASI-NN] Whisper backend: failed to open {} for writing."sv, + Fname); + return false; + } + spdlog::info("[WASI-NN] Whisper backend: saving srt output to {}."sv, Fname); + + const int NSegments = whisper_full_n_segments(Ctx); + for (int I = 0; I < NSegments; ++I) { + const std::string &Text = whisper_full_get_segment_text(Ctx, I); + const int64_t T0 = whisper_full_get_segment_t0(Ctx, I); + const int64_t T1 = whisper_full_get_segment_t1(Ctx, I); + std::string Speaker = ""; + + if (Params.Diarize && PCMF32s.size() == 2) { + Speaker = estimateDiarizationSpeaker(PCMF32s, T0, T1); + } + + Fout << I + 1 + Params.OffsetN << "\n"; + Fout << toTimestamp(T0, true) << " --> " << toTimestamp(T1, true) << "\n"; + Fout << Speaker << Text << "\n\n"; + } + return true; +} + +static bool outputLrc(WasiNNEnvironment &Env, whisper_context *Ctx, + const std::string &Fname, const Config &Params, + const std::vector> &PCMF32s) { + WasmEdge::FStream::OFStream Fout(Fname, Env.getEnv()); + if (!Fout.is_open()) { + spdlog::error("[WASI-NN] Whisper backend: failed to open {} for writing."sv, + Fname); + return false; + } + + spdlog::info("[WASI-NN] Whisper backend: saving lrc output to {}."sv, Fname); + + Fout << "[by:whisper.cpp]\n"; + + const int NSegments = whisper_full_n_segments(Ctx); + for (int I = 0; I < NSegments; ++I) { + const std::string &text = whisper_full_get_segment_text(Ctx, I); + const int64_t T = whisper_full_get_segment_t0(Ctx, I); + + int64_t Msec = T * 10; + int64_t Min = Msec / (1000 * 60); + Msec = Msec - Min * (1000 * 60); + int64_t Sec = Msec / 1000; + Msec = Msec - Sec * 1000; + + char Buf[16]; + snprintf(Buf, sizeof(Buf), "%02d:%02d.%02d", static_cast(Min), + static_cast(Sec), static_cast((Msec / 10))); + std::string TimestampLrc = std::string(Buf); + std::string Speaker = ""; + + if (Params.Diarize && PCMF32s.size() == 2) { + const int64_t t0 = whisper_full_get_segment_t0(Ctx, I); + const int64_t t1 = whisper_full_get_segment_t1(Ctx, I); + Speaker = estimateDiarizationSpeaker(PCMF32s, t0, t1); + } + + Fout << '[' << TimestampLrc << ']' << Speaker << text << "\n"; + } + + return true; +} + +std::string escapeDoubleQuotesAndBackslashes(const std::string &Str) { + std::string Escaped; + for (auto W : Str) { + if (W == '"' || W == '\\') { + Escaped += '\\'; + } + Escaped += W; + } + return Escaped; +} + +bool outputJson(WasiNNEnvironment &Env, whisper_context *Ctx, + const std::string &Fname, const Config &Params, + const std::vector> &PCMF32s, bool Full) { + WasmEdge::FStream::OFStream Fout(Fname, Env.getEnv()); + int Indent = 0; + + auto Doindent = [&]() { + for (int i = 0; i < Indent; i++) + Fout << "\t"; + }; + + auto StartArr = [&](const char *Name) { + Doindent(); + Fout << "\"" << Name << "\": [\n"; + Indent++; + }; + + auto EndArr = [&](bool End) { + Indent--; + Doindent(); + Fout << (End ? "]\n" : "],\n"); + }; + + auto StartObj = [&](const char *Name) { + Doindent(); + if (Name) { + Fout << "\"" << Name << "\": {\n"; + } else { + Fout << "{\n"; + } + Indent++; + }; + + auto EndObj = [&](bool End) { + Indent--; + Doindent(); + Fout << (End ? "}\n" : "},\n"); + }; + + auto StartValue = [&](const char *Name) { + Doindent(); + Fout << "\"" << Name << "\": "; + }; + + auto ValueS = [&](const char *Name, const std::string &Val, bool End) { + StartValue(Name); + std::string ValEscaped = escapeDoubleQuotesAndBackslashes(Val); + Fout << "\"" << ValEscaped << (End ? "\"\n" : "\",\n"); + }; + + auto EndValue = [&](bool End) { Fout << (End ? "\n" : ",\n"); }; + + auto ValueI = [&](const char *Name, const int64_t Val, bool End) { + StartValue(Name); + Fout << Val; + EndValue(End); + }; + + auto ValueF = [&](const char *Name, const float Val, bool End) { + StartValue(Name); + Fout << Val; + EndValue(End); + }; + + auto ValueB = [&](const char *Name, const bool Val, bool End) { + StartValue(Name); + Fout << (Val ? "true" : "false"); + EndValue(End); + }; + + auto TimesO = [&](int64_t T0, int64_t T1, bool End) { + StartObj("timestamps"); + ValueS("from", toTimestamp(T0, true), false); + ValueS("to", toTimestamp(T1, true), true); + EndObj(false); + StartObj("offsets"); + ValueI("from", T0 * 10, false); + ValueI("to", T1 * 10, true); + EndObj(End); + }; + + if (!Fout.is_open()) { + spdlog::error("[WASI-NN] Whisper backend: failed to open {} for writing."sv, + Fname); + return false; + } + + spdlog::info("[WASI-NN] Whisper backend: saving json output to {}."sv, Fname); + + StartObj(nullptr); + ValueS("systeminfo", whisper_print_system_info(), false); + StartObj("model"); + ValueS("type", whisper_model_type_readable(Ctx), false); + ValueB("multilingual", whisper_is_multilingual(Ctx), false); + ValueI("vocab", whisper_model_n_vocab(Ctx), false); + StartObj("audio"); + ValueI("ctx", whisper_model_n_audio_ctx(Ctx), false); + ValueI("state", whisper_model_n_audio_state(Ctx), false); + ValueI("head", whisper_model_n_audio_head(Ctx), false); + ValueI("layer", whisper_model_n_audio_layer(Ctx), true); + EndObj(false); + StartObj("text"); + ValueI("ctx", whisper_model_n_text_ctx(Ctx), false); + ValueI("state", whisper_model_n_text_state(Ctx), false); + ValueI("head", whisper_model_n_text_head(Ctx), false); + ValueI("layer", whisper_model_n_text_layer(Ctx), true); + EndObj(false); + ValueI("mels", whisper_model_n_mels(Ctx), false); + ValueI("ftype", whisper_model_ftype(Ctx), true); + EndObj(false); + StartObj("params"); + ValueS("model", "Wasi-nn preload", false); + ValueS("language", Params.SpokenLanguage, false); + ValueB("translate", Params.Translate, true); + EndObj(false); + StartObj("result"); + ValueS("language", whisper_lang_str(whisper_full_lang_id(Ctx)), true); + EndObj(false); + StartArr("transcription"); + + const int NSegments = whisper_full_n_segments(Ctx); + for (int I = 0; I < NSegments; ++I) { + const std::string &Text = whisper_full_get_segment_text(Ctx, I); + + const int64_t T0 = whisper_full_get_segment_t0(Ctx, I); + const int64_t T1 = whisper_full_get_segment_t1(Ctx, I); + + StartObj(nullptr); + TimesO(T0, T1, false); + ValueS("text", Text, !Params.Diarize && !Params.TinyDiarize && !Full); + + if (Full) { + StartArr("tokens"); + const int n = whisper_full_n_tokens(Ctx, I); + for (int j = 0; j < n; ++j) { + auto token = whisper_full_get_token_data(Ctx, I, j); + StartObj(nullptr); + ValueS("text", whisper_token_to_str(Ctx, token.id), false); + if (token.t0 > -1 && token.t1 > -1) { + // If we have per-token timestamps, write them out + TimesO(token.t0, token.t1, false); + } + ValueI("id", token.id, false); + ValueF("p", token.p, false); + ValueF("t_dtw", token.t_dtw, true); + EndObj(j == (n - 1)); + } + EndArr(!Params.Diarize && !Params.TinyDiarize); + } + + if (Params.Diarize && PCMF32s.size() == 2) { + ValueS("speaker", estimateDiarizationSpeaker(PCMF32s, T0, T1, true), + true); + } + + if (Params.TinyDiarize) { + ValueB("speaker_turn_next", + whisper_full_get_segment_speaker_turn_next(Ctx, I), true); + } + EndObj(I == (NSegments - 1)); + } + + EndArr(true); + EndObj(true); + return true; +} + +bool checkAudioRIFF(const std::string_view Buf, const std::string_view Format) { + if (Buf.size() < 12 || Buf.substr(0, 4) != "RIFF"sv) { + return false; + } + if (Buf.substr(8, 4) != Format) { + return false; + } + uint32_t ChunkSize = *reinterpret_cast(Buf.data() + 4); + if (ChunkSize + 8 != Buf.size()) { + return false; + } + return true; +} + +bool loadWAV(Span Buf, std::vector &PCMF32, + std::vector> &PCMF32s, bool Stereo) { + // Do not use the helper function from whisper.cpp examples to avoid copying. + drwav WAV; + const uint32_t ConstSampleRate = 16000; + + if (!drwav_init_memory(&WAV, Buf.data(), Buf.size(), nullptr)) { + spdlog::error("[WASI-NN] Whisper backend: load WAV failed."sv); + return false; + } + + if (WAV.channels != 1 && WAV.channels != 2) { + spdlog::error("[WASI-NN] Whisper backend: WAV must be mono or stereo."sv); + drwav_uninit(&WAV); + return false; + } + + if (WAV.sampleRate != ConstSampleRate) { + spdlog::error("[WASI-NN] Whisper backend: WAV must be {} kHz."sv, + ConstSampleRate / 1000); + drwav_uninit(&WAV); + return false; + } + + if (WAV.bitsPerSample != 16) { + spdlog::error("[WASI-NN] Whisper backend: WAV must be 16-bit."sv); + drwav_uninit(&WAV); + return false; + } + + const uint32_t N = WAV.totalPCMFrameCount; + std::vector PCM16(N * WAV.channels); + drwav_read_pcm_frames_s16(&WAV, N, PCM16.data()); + drwav_uninit(&WAV); + + PCMF32.resize(N); + if (WAV.channels == 1) { + for (uint64_t I = 0; I < N; I++) { + PCMF32[I] = static_cast(PCM16[I]) / 32768.0f; + } + } else { + for (uint64_t I = 0; I < N; I++) { + PCMF32[I] = + static_cast(PCM16[2 * I] + PCM16[2 * I + 1]) / 65536.0f; + } + } + if (Stereo) { + PCMF32s.resize(2); + + PCMF32s[0].resize(N); + PCMF32s[1].resize(N); + for (uint64_t I = 0; I < N; I++) { + PCMF32s[0][I] = float(PCM16[2 * I]) / 32768.0f; + PCMF32s[1][I] = float(PCM16[2 * I + 1]) / 32768.0f; + } + } + return true; +} + +void WhisperLogCallback(ggml_log_level LogLevel, const char *LogText, + void *UserData) { + const Graph &GraphRef = *reinterpret_cast(UserData); + if (!GraphRef.WhisperConfig.EnableLog) { + return; + } + std::string Text(LogText); + // Remove the trailing newlines. + Text = Text.erase(Text.find_last_not_of("\n") + 1); + // Skip for "." + if (Text == ".") { + return; + } + if (LogLevel == GGML_LOG_LEVEL_ERROR) { + spdlog::error("[WASI-NN] whisper.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_WARN) { + spdlog::warn("[WASI-NN] whisper.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_INFO) { + spdlog::info("[WASI-NN] whisper.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_DEBUG) { + spdlog::debug("[WASI-NN] whisper.cpp: {}"sv, Text); + } +} + +void WhisperOutputSegmentCallback(struct whisper_context *WhisperCtx, + struct whisper_state * /* state */, int NewN, + void *UserData) { + auto &CxtRef = *reinterpret_cast(UserData); + const int SegN = whisper_full_n_segments(WhisperCtx); + + std::string Speaker = ""; + // Output the last new N segments. + for (int I = SegN - NewN; I < SegN; I++) { + int64_t T0 = 0; + int64_t T1 = 0; + if (!CxtRef.WhisperConfig.NoTimestamps) { + T0 = whisper_full_get_segment_t0(WhisperCtx, I); + T1 = whisper_full_get_segment_t1(WhisperCtx, I); + CxtRef.Outputs += "["; + CxtRef.Outputs += toTimestamp(T0, false); + CxtRef.Outputs += " --> "; + CxtRef.Outputs += toTimestamp(T1, false); + CxtRef.Outputs += "] "; + } + if (CxtRef.WhisperConfig.Diarize && CxtRef.InputPCMs.size() == 2) { + Speaker = estimateDiarizationSpeaker(CxtRef.InputPCMs, T0, T1); + } + CxtRef.Outputs += Speaker + whisper_full_get_segment_text(WhisperCtx, I); + if (!CxtRef.WhisperConfig.NoTimestamps || CxtRef.WhisperConfig.Diarize) { + CxtRef.Outputs += "\n"; + } + } +} + +void setWhisperParams(Context &CxtRef) noexcept { + auto &WParam = CxtRef.WhisperParams; + auto &ConfigRef = CxtRef.WhisperConfig; + WParam.n_threads = ConfigRef.ThreadsNum; + WParam.n_max_text_ctx = ConfigRef.MaxTokenContext; + WParam.offset_ms = ConfigRef.TimeOffsetMS; + WParam.duration_ms = ConfigRef.DurationMS; + WParam.print_progress = false; + WParam.thold_pt = ConfigRef.WordThreshold; + WParam.max_len = ConfigRef.MaxSegmentLength; + WParam.token_timestamps = (WParam.max_len > 0); + WParam.split_on_word = ConfigRef.SplitOnWord; + WParam.translate = ConfigRef.Translate; + WParam.language = ConfigRef.SpokenLanguage.c_str(); + WParam.detect_language = ConfigRef.DetectLanguage; + WParam.initial_prompt = ConfigRef.InitialPrompt.c_str(); + WParam.temperature_inc = ConfigRef.TemperatureInc; + WParam.temperature = ConfigRef.Temperature; + WParam.entropy_thold = ConfigRef.EntropyThreshold; + WParam.logprob_thold = ConfigRef.LogprobThreshold; + WParam.grammar_penalty = ConfigRef.GrammarPenalty; + WParam.new_segment_callback = WhisperOutputSegmentCallback; + WParam.new_segment_callback_user_data = &CxtRef; + WParam.greedy.best_of = ConfigRef.BestOf; + WParam.print_timestamps = !ConfigRef.NoTimestamps; + WParam.no_timestamps = ConfigRef.NoTimestamps; + WParam.audio_ctx = ConfigRef.AudioCtx; + WParam.strategy = + (ConfigRef.BeamSize > 1) + ? whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH + : whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY; + WParam.beam_search.beam_size = ConfigRef.BeamSize; + + if (ConfigRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: threads: {}"sv, + ConfigRef.ThreadsNum); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: processors: {}"sv, + ConfigRef.ProcessorsNum); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: max-context: {}"sv, + ConfigRef.MaxTokenContext); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: offset-t: {}"sv, + ConfigRef.TimeOffsetMS); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: duration: {}"sv, + ConfigRef.DurationMS); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: max-len: {}"sv, + ConfigRef.MaxSegmentLength); + spdlog::info( + "[WASI-NN][Debug] Whisper backend: Config: split-on-word : {}"sv, + ConfigRef.SplitOnWord); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: translate: {}"sv, + ConfigRef.Translate); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: language: \"{}\""sv, + ConfigRef.SpokenLanguage); + spdlog::info( + "[WASI-NN][Debug] Whisper backend: Config: detect-language: {}"sv, + ConfigRef.DetectLanguage); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: temperature: {}"sv, + ConfigRef.Temperature); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: prompt: \"{}\""sv, + ConfigRef.InitialPrompt); + } +} + +Expect parseMetadata(Config &ConfigRef, + const std::string &Metadata) noexcept { + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] Whisper backend: Parse metadata error."sv); + return ErrNo::InvalidEncoding; + } + + auto PrintParsedOption = [&](std::string_view Name, const auto &Val) { + if (ConfigRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: Parsed metadata -- {}:{}"sv, Name, + Val); + } + }; + + // Get metadata from the json. + // Currently supported metadata: + // Plugin parameters (used by this plugin): + // enable-log: bool + // enable-debug-log: bool + // threads: uint32_t + // processors: uint32_t + // offset-t: uint32_t + // duration: uint32_t + // max-context: uint32_t + // max-len: uint32_t + // split-on-word: bool + // translate: bool + // language: string + // detect-language: bool + // temperature: float + // prompt: string + + // The plugin parameters. + if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-log"].get().get(ConfigRef.EnableLog); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the enable-log " + "option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { + auto Err = + Doc["enable-debug-log"].get().get(ConfigRef.EnableDebugLog); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the enable-debug-log " + "option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("threads").error() == simdjson::SUCCESS) { + auto Err = Doc["threads"].get().get(ConfigRef.ThreadsNum); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the threads option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("threads"sv, ConfigRef.ThreadsNum); + } + if (Doc.at_key("processors").error() == simdjson::SUCCESS) { + auto Err = Doc["processors"].get().get(ConfigRef.ProcessorsNum); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the processors option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("processors"sv, ConfigRef.ProcessorsNum); + } + if (Doc.at_key("offset-t").error() == simdjson::SUCCESS) { + auto Err = Doc["offset-t"].get().get(ConfigRef.TimeOffsetMS); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the offset-t option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("offset-t"sv, ConfigRef.TimeOffsetMS); + } + if (Doc.at_key("duration").error() == simdjson::SUCCESS) { + auto Err = Doc["duration"].get().get(ConfigRef.DurationMS); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the duration option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("duration"sv, ConfigRef.DurationMS); + } + if (Doc.at_key("max-context").error() == simdjson::SUCCESS) { + int64_t MaxContext = 0; + auto Err = Doc["max-context"].get().get(MaxContext); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the max-context option."sv); + return ErrNo::InvalidArgument; + } + if (MaxContext >= 0) { + ConfigRef.MaxTokenContext = static_cast(MaxContext); + PrintParsedOption("max-context"sv, ConfigRef.MaxTokenContext); + } + } + if (Doc.at_key("max-len").error() == simdjson::SUCCESS) { + auto Err = Doc["max-len"].get().get(ConfigRef.MaxSegmentLength); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the max-len option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("max-len"sv, ConfigRef.MaxSegmentLength); + } + if (Doc.at_key("split-on-word").error() == simdjson::SUCCESS) { + auto Err = Doc["split-on-word"].get().get(ConfigRef.SplitOnWord); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the split-on-word " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("split-on-word"sv, ConfigRef.SplitOnWord); + } + if (Doc.at_key("translate").error() == simdjson::SUCCESS) { + auto Err = Doc["translate"].get().get(ConfigRef.Translate); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the translate " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("translate"sv, ConfigRef.Translate); + } + if (Doc.at_key("language").error() == simdjson::SUCCESS) { + std::string_view Language; + auto Err = Doc["language"].get().get(Language); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the language " + "option."sv); + return ErrNo::InvalidArgument; + } + ConfigRef.SpokenLanguage = Language; + PrintParsedOption("language"sv, ConfigRef.SpokenLanguage); + } + if (Doc.at_key("detect-language").error() == simdjson::SUCCESS) { + auto Err = Doc["detect-language"].get().get(ConfigRef.DetectLanguage); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the detect-language " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("detect-language"sv, ConfigRef.DetectLanguage); + } + if (Doc.at_key("temperature").error() == simdjson::SUCCESS) { + double Temperature; + auto Err = Doc["temperature"].get().get(Temperature); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the temperature option."sv); + return ErrNo::InvalidArgument; + } + ConfigRef.Temperature = static_cast(Temperature); + PrintParsedOption("temperature"sv, ConfigRef.Temperature); + } + if (Doc.at_key("prompt").error() == simdjson::SUCCESS) { + std::string_view Prompt; + auto Err = Doc["prompt"].get().get(Prompt); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the prompt option."sv); + return ErrNo::InvalidArgument; + } + ConfigRef.InitialPrompt = Prompt; + PrintParsedOption("prompt"sv, ConfigRef.InitialPrompt); + } + if (Doc.at_key("best-of").error() == simdjson::SUCCESS) { + auto Err = Doc["best-of"].get().get(ConfigRef.BestOf); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the best-of option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("best-of"sv, ConfigRef.BestOf); + } + if (Doc.at_key("beam-size").error() == simdjson::SUCCESS) { + auto Err = Doc["beam-size"].get().get(ConfigRef.BeamSize); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the beam-size option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("beam-size"sv, ConfigRef.BeamSize); + } + if (Doc.at_key("output-srt").error() == simdjson::SUCCESS) { + auto Err = Doc["output-srt"].get().get(ConfigRef.OutputSrt); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the output-srt " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("output-srt"sv, ConfigRef.OutputLrc); + } + if (Doc.at_key("output-lrc").error() == simdjson::SUCCESS) { + auto Err = Doc["output-lrc"].get().get(ConfigRef.OutputLrc); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the output-lrc" + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("output-lrc"sv, ConfigRef.OutputLrc); + } + if (Doc.at_key("output-json").error() == simdjson::SUCCESS) { + auto Err = Doc["output-json"].get().get(ConfigRef.OutputJson); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the output-json " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("output-json"sv, ConfigRef.OutputJson); + } + if (Doc.at_key("output-json-full").error() == simdjson::SUCCESS) { + auto Err = + Doc["output-json-full"].get().get(ConfigRef.OutputJsonFull); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the output-json-full " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("output-json-full"sv, ConfigRef.OutputJsonFull); + } + if (Doc.at_key("no-timestamps").error() == simdjson::SUCCESS) { + auto Err = Doc["no-timestamps"].get().get(ConfigRef.NoTimestamps); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the no-timestamps " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("no-timestamps"sv, ConfigRef.NoTimestamps); + } + if (Doc.at_key("output-file").error() == simdjson::SUCCESS) { + std::string_view FileName; + auto Err = Doc["output-file"].get().get(FileName); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the output file" + "option."sv); + return ErrNo::InvalidArgument; + } + ConfigRef.FileName = FileName; + PrintParsedOption("output-file"sv, ConfigRef.FileName); + } + if (Doc.at_key("audio-ctx").error() == simdjson::SUCCESS) { + auto Err = Doc["audio-ctx"].get().get(ConfigRef.AudioCtx); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the audio-ctx " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("audio-ctx"sv, ConfigRef.AudioCtx); + } + if (Doc.at_key("diarize").error() == simdjson::SUCCESS) { + auto Err = Doc["diarize"].get().get(ConfigRef.Diarize); + if (Err) { + spdlog::error("[WASI-NN] Whisper backend: Unable to retrieve the diarize " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("diarize"sv, ConfigRef.Diarize); + } + if (Doc.at_key("offset-n").error() == simdjson::SUCCESS) { + auto Err = Doc["offset-n"].get().get(ConfigRef.OffsetN); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the offset-n " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("offset-n"sv, ConfigRef.OffsetN); + } + + return ErrNo::Success; +} + +Expect handleTranslationConfig(whisper_context *WhisperCtx, + Config &ConfigRef) noexcept { + assuming(WhisperCtx); + + // Check the language. + if (ConfigRef.SpokenLanguage != "auto"sv && + whisper_lang_id(ConfigRef.SpokenLanguage.c_str()) == -1) { + spdlog::error("[WASI-NN] Whisper backend: Error: unknown language {}."sv, + ConfigRef.SpokenLanguage); + return ErrNo::InvalidArgument; + } + + // Check the translate option. + if (!whisper_is_multilingual(WhisperCtx)) { + if (ConfigRef.SpokenLanguage != "en"sv || ConfigRef.Translate) { + ConfigRef.SpokenLanguage = "en"sv; + ConfigRef.Translate = false; + if (ConfigRef.EnableLog) { + spdlog::info( + "[WASI-NN] Whisper backend: Model is not multilingual. Ignoring " + "language and translation options"sv); + } + } + } + if (ConfigRef.DetectLanguage) { + ConfigRef.SpokenLanguage = "auto"sv; + } + return ErrNo::Success; +} + +} // Namespace + +Expect load(WasiNNEnvironment &Env, Span> Builders, + [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { + // Add a new graph. + uint32_t GId = Env.newGraph(Backend::Whisper); + auto &GraphRef = Env.NNGraph[GId].get(); + + // Initialize the parameters. + auto CParam = whisper_context_default_params(); + GraphRef.ModelFilePath = ""sv; + GraphRef.WhisperConfig.SpokenLanguage = "en"sv; + GraphRef.UseGPU = CParam.use_gpu; + GraphRef.MainGPU = CParam.gpu_device; + + // Set whisper log callback. + whisper_log_set(WhisperLogCallback, &GraphRef); + + // If the graph builder length is greater than 1, builder[1] contains the + // metadata. + if (Builders.size() > 1) { + const std::string Metadata(reinterpret_cast(Builders[1].data()), + Builders[1].size()); + // Ignore context or model updates when initializing the graph. + auto Res = parseMetadata(GraphRef.WhisperConfig, Metadata); + if (Res != ErrNo::Success) { + spdlog::error("[WASI-NN] Whisper backend: Failed to parse metadata."sv); + Env.deleteGraph(GId); + return Res; + } + } + + // Handle the model path. + if (GraphRef.WhisperConfig.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: Handling model path."sv); + } + auto Weight = Builders[0]; + const std::string_view BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + if (BinModel.substr(0, 8) == "preload:"sv) { + GraphRef.ModelFilePath = BinModel.substr(8); + } + + // Initialize whisper context from model file with parameters. + if (GraphRef.WhisperConfig.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: Initialize whisper context with " + "given parameters"sv); + } + if (GraphRef.ModelFilePath == ""sv) { + GraphRef.WhisperCtx = whisper_init_from_buffer_with_params( + Weight.data(), Weight.size(), CParam); + } else { + GraphRef.WhisperCtx = whisper_init_from_file_with_params( + GraphRef.ModelFilePath.c_str(), CParam); + } + if (GraphRef.WhisperCtx == nullptr) { + spdlog::error( + "[WASI-NN] Whisper backend: Error: unable to init whisper context from " + "model."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + if (GraphRef.WhisperConfig.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: Initialize whisper context with " + "given parameters...Done"sv); + } + + auto ResTranslateConfig = + handleTranslationConfig(GraphRef.WhisperCtx, GraphRef.WhisperConfig); + if (ResTranslateConfig != ErrNo::Success) { + Env.deleteGraph(GId); + return ResTranslateConfig; + } + + // Store the loaded graph. + GraphId = GId; + Env.NNGraph[GId].setReady(); + + return ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.WhisperConfig.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: initExecCtx"sv); + } + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + auto &CxtRef = Env.NNContext[ContextId].get(); + CxtRef.WhisperParams = whisper_full_default_params( + whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH); + setWhisperParams(CxtRef); + if (GraphRef.WhisperConfig.EnableLog) { + spdlog::info("[WASI-NN] Whisper backend: whisper_system_info: {}"sv, + whisper_print_system_info()); + } + Env.NNContext[ContextId].setReady(); + if (GraphRef.WhisperConfig.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: initExecCtx...Done"sv); + } + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index [[maybe_unused]], + const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + if (CxtRef.WhisperConfig.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: setInput"sv); + } + + // Use index 1 for metadata. + if (Index == 1) { + if (CxtRef.WhisperConfig.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: found Metadata, processing"sv); + } + // Set the whisper config of this context as the graph default first. + // This will reset the config and inherit settings from the graph metadata. + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + CxtRef.WhisperConfig = GraphRef.WhisperConfig; + const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + auto Res = parseMetadata(CxtRef.WhisperConfig, Metadata); + if (Res != ErrNo::Success) { + spdlog::error("[WASI-NN] Whisper backend: Failed to parse metadata."sv); + return Res; + } + Res = handleTranslationConfig(GraphRef.WhisperCtx, CxtRef.WhisperConfig); + if (Res != ErrNo::Success) { + return Res; + } + setWhisperParams(CxtRef); + if (CxtRef.WhisperConfig.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: found Metadata, " + "processing...Done"sv); + } + return ErrNo::Success; + } + + if (Tensor.Dimension.size() != 2) { + spdlog::error("[WASI-NN] Tensor dimension is out of range, expect 2-dim, " + "but got {}-dim."sv, + Tensor.Dimension.size()); + return WASINN::ErrNo::InvalidArgument; + } + if (Tensor.Dimension[0] != 1) { + spdlog::error("[WASI-NN] Only 1 channel supported for now."sv); + return WASINN::ErrNo::InvalidArgument; + } + + // Tensor type not used here. Not to check this. + + // Check the input audio file format and load. Currently WAV supported. + if (!checkAudioRIFF( + std::string_view(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()), + "WAVE"sv)) { + spdlog::error("[WASI-NN] Only WAV format supported now."sv); + return WASINN::ErrNo::InvalidArgument; + } + if (!loadWAV(Tensor.Tensor, CxtRef.InputPCM, CxtRef.InputPCMs, + CxtRef.WhisperConfig.Diarize)) { + return WASINN::ErrNo::InvalidArgument; + } + + if (CxtRef.WhisperConfig.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: setInput...Done"sv); + } + return ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (CxtRef.WhisperConfig.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: getOutput with Index {}"sv, + Index); + } + + // Check out buffer max size. + if (OutBuffer.size() < CxtRef.Outputs.length()) { + spdlog::error("[WASI-NN] Expect out buffer max size {}, but got {}"sv, + CxtRef.Outputs.length(), OutBuffer.size()); + return WASINN::ErrNo::InvalidArgument; + } + + std::copy_n(CxtRef.Outputs.data(), CxtRef.Outputs.length(), OutBuffer.data()); + BytesWritten = CxtRef.Outputs.length(); + if (CxtRef.WhisperConfig.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: getOutput with Index {}...Done"sv, + Index); + } + + if (CxtRef.WhisperConfig.OutputSrt) { + const auto Fname = CxtRef.WhisperConfig.FileName + ".srt"; + outputSrt(Env, GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, + CxtRef.InputPCMs); + } + + if (CxtRef.WhisperConfig.OutputLrc) { + const auto Fname = CxtRef.WhisperConfig.FileName + ".lrc"; + outputLrc(Env, GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, + CxtRef.InputPCMs); + } + + if (CxtRef.WhisperConfig.OutputJson) { + const auto Fname = CxtRef.WhisperConfig.FileName + ".json"; + outputJson(Env, GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, + CxtRef.InputPCMs, CxtRef.WhisperConfig.OutputJsonFull); + } + + return ErrNo::Success; +} + +Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (CxtRef.WhisperConfig.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: compute"sv); + } + + CxtRef.Outputs.clear(); + if (whisper_full_parallel(GraphRef.WhisperCtx, CxtRef.WhisperParams, + CxtRef.InputPCM.data(), CxtRef.InputPCM.size(), + CxtRef.WhisperConfig.ProcessorsNum) != 0) { + spdlog::error( + "[WASI-NN] Whisper backend: Error: failed to process audio."sv); + return ErrNo::RuntimeError; + } + + if (CxtRef.WhisperConfig.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: compute...Done"sv); + } + return ErrNo::Success; +} + +Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + const bool IsDebugLog = GraphRef.WhisperConfig.EnableDebugLog; + if (IsDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: unload"sv); + } + if (GraphRef.WhisperCtx != nullptr) { + if (IsDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: unload: free whisper context"sv); + } + whisper_free(GraphRef.WhisperCtx); + GraphRef.WhisperCtx = nullptr; + if (IsDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: unload: free whisper context...Done"sv); + } + } + Env.deleteGraph(GraphId); + Env.mdRemoveById(GraphId); + if (IsDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: unload...Done"sv); + } + return ErrNo::Success; +} + +Expect finalizeExecCtx(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + const bool IsDebugLog = CxtRef.WhisperConfig.EnableDebugLog; + if (IsDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: finalize_execution_context"sv); + } + // TODO: Free resources + Env.deleteContext(ContextId); + if (IsDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: finalize_execution_context...Done"sv); + } + return ErrNo::Success; +} +#else + +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] Whisper backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"whisper\" to build it."sv); + return ErrNo::InvalidArgument; +} +} // Namespace + +Expect load(WasiNNEnvironment &, Span>, Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect unload(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect finalizeExecCtx(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} + +#endif +} // Namespace WasmEdge::Host::WASINN::Whisper diff --git a/plugins/wasi_nn/wasinn_whisper.h b/plugins/wasi_nn/wasinn_whisper.h new file mode 100644 index 00000000..64745040 --- /dev/null +++ b/plugins/wasi_nn/wasinn_whisper.h @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" +#include + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER +#include + +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::Whisper { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER + +struct Config { + // Whisper parameters: + uint64_t ThreadsNum = + std::min(static_cast(4), + static_cast(std::thread::hardware_concurrency())); + uint64_t ProcessorsNum = 1; + uint64_t MaxTokenContext = 16384; + uint64_t TimeOffsetMS = 0; + uint64_t DurationMS = 0; + uint64_t MaxSegmentLength = 0; + bool EnableLog = false; + bool EnableDebugLog = false; + bool Translate = false; + bool DetectLanguage = false; + bool SplitOnWord = false; + bool Diarize = false; + bool TinyDiarize = false; + std::string SpokenLanguage; + std::string InitialPrompt; + uint64_t BestOf = 5; + uint64_t BeamSize = 5; + uint64_t OffsetN = 0; + std::string FileName = "output"; + bool OutputSrt = false; + bool OutputLrc = false; + bool OutputJson = false; + bool OutputJsonFull = false; + bool NoTimestamps = false; + uint64_t AudioCtx = 0; + // Sampling parameters: + float WordThreshold = 0.01f; + float EntropyThreshold = 2.40f; + float LogprobThreshold = -1.00f; + float Temperature = 0.0f; + float TemperatureInc = 0.2f; + float GrammarPenalty = 100.0f; +}; + +struct Graph { + whisper_context *WhisperCtx = nullptr; + std::string ModelFilePath; + // Whisper config: + Config WhisperConfig; + // Context parameters: + bool UseGPU = true; + int64_t MainGPU = 0; // Use GPU 0 by default +}; + +struct Context { +public: + Context(uint32_t GId, Graph &G) noexcept + : GraphId(GId), WhisperConfig(G.WhisperConfig) {} + uint32_t GraphId; + // mono-channel F32 PCM input. + std::vector InputPCM; + std::vector> InputPCMs; + // Whisper config. Inherited from the graph and updated from metadata when + // setting input. + Config WhisperConfig; + whisper_full_params WhisperParams; + // Recognition outputs. + std::string Outputs; +}; +#else +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; +Expect finalizeExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::Whisper diff --git a/plugins/wasi_nn/wasinnbase.h b/plugins/wasi_nn/wasinnbase.h new file mode 100644 index 00000000..1f5f70d8 --- /dev/null +++ b/plugins/wasi_nn/wasinnbase.h @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinnenv.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template class WasiNN : public Runtime::HostFunction { +public: + WasiNN(WASINN::WasiNNEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + static constexpr uint32_t castErrNo(WASINN::ErrNo E) noexcept { + return static_cast(E); + } + + WASINN::WasiNNEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp new file mode 100644 index 00000000..6f4a1c2e --- /dev/null +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinnenv.h" +#include "wasinnmodule.h" +#include "wasinntypes.h" + +#include + +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +#include +#endif + +using namespace std::literals; + +namespace WasmEdge { +namespace Host { + +namespace WASINN { + +namespace { +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasiNNModule; +} + +std::map BackendMap = { + {"openvino"sv, Backend::OpenVINO}, + {"onnx"sv, Backend::ONNX}, + {"tensorflow"sv, Backend::Tensorflow}, + {"pytorch"sv, Backend::PyTorch}, + {"pytorchaoti"sv, Backend::PyTorch}, + {"tensorflowlite"sv, Backend::TensorflowLite}, + {"autodetect"sv, Backend::Autodetect}, + {"ggml"sv, Backend::GGML}, + {"neuralspeed"sv, Backend::NeuralSpeed}, + {"whisper"sv, Backend::Whisper}, + {"mlx"sv, Backend::MLX}, + {"piper"sv, Backend::Piper}, + {"chattts"sv, Backend::ChatTTS}, + {"openvinogenai"sv, Backend::OpenVINOGenAI}, + {"bitnet"sv, Backend::BitNet}}; + +std::map DeviceMap = {{"cpu"sv, Device::CPU}, + {"gpu"sv, Device::GPU}, + {"tpu"sv, Device::TPU}, + {"auto"sv, Device::AUTO}}; + +bool load(const std::filesystem::path &Path, std::vector &Data) { + std::ifstream File(Path, std::ios::binary); + if (!File.is_open()) { + spdlog::error("[WASI-NN] Preload model fail."sv); + return false; + } + File.seekg(0, std::ios::end); + std::streampos FileSize = File.tellg(); + File.seekg(0, std::ios::beg); + Data.resize(FileSize); + File.read(reinterpret_cast(Data.data()), FileSize); + File.close(); + return true; +} +} // namespace + +WasiNNEnvironment::WasiNNEnvironment() noexcept { +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (getenv("_WASI_NN_RPCSERVER") == nullptr) { + // RPC client mode + auto URI = NNRPCURI.value(); + if (!URI.empty()) { + std::string_view UnixPrefix = "unix://"; + if (URI.substr(0, UnixPrefix.length()) != UnixPrefix) { + spdlog::warn("[WASI-NN] Expected \"unix://...\", got \"{}\""sv, URI); + } + auto Cred = grpc::InsecureChannelCredentials(); // safe for unix://... + NNRPCChannel = grpc::CreateChannel(URI, Cred); + if (NNModels.value().size() > 0) { + spdlog::warn( + "[WASI-NN] nn-preload has to be specified on the RPC server side, not on the client side"sv); + } + return; + } + } +#endif + // Preload NN Models + for (const auto &M : NNModels.value()) { + std::istringstream ISS(M); + const char Delimiter = ':'; + std::string Name; + std::string Encode; + std::string Target; + std::vector Paths; + std::getline(ISS, Name, Delimiter); + std::getline(ISS, Encode, Delimiter); + std::getline(ISS, Target, Delimiter); + std::string Path; + while (std::getline(ISS, Path, Delimiter)) { + Paths.push_back(Path); + } + std::vector> Models; + Models.reserve(Paths.size()); + std::transform(Encode.begin(), Encode.end(), Encode.begin(), + [](unsigned char C) { + return static_cast(std::tolower(C)); + }); + std::transform(Target.begin(), Target.end(), Target.begin(), + [](unsigned char C) { + return static_cast(std::tolower(C)); + }); + auto Backend = BackendMap.find(Encode); + auto Device = DeviceMap.find(Target); + if (Backend != BackendMap.end() && Device != DeviceMap.end()) { + if (Backend->second == Backend::GGML || + Backend->second == Backend::BitNet || + (Backend->second == Backend::PyTorch && Encode == "pytorchaoti"sv)) { + // In GGML, we only support loading one model from nn-preload + // config. To handle paths on Windows that contains `:` in the + // path, we combine the Paths into a single string separated by + // `:`. + std::string P; + for (const std::string &PathSegment : Paths) { + P += PathSegment; + if (PathSegment != Paths.back()) { + P += ":"; + } + } + // We write model path to model data to avoid file IO in + // llama.cpp. + std::string ModelPath = "preload:" + P; + std::vector ModelPathData(ModelPath.begin(), ModelPath.end()); + Models.push_back(std::move(ModelPathData)); + } else { + for (const std::string &P : Paths) { + std::vector Model; + if (load(std::filesystem::u8path(P), Model)) { + Models.push_back(std::move(Model)); + } + } + } + RawMdMap.emplace(Name, std::make_tuple(std::move(Models), Backend->second, + Device->second)); + } else { + spdlog::error( + "[WASI-NN] Preload Model's Backend or Device is Not Support."sv); + } + } + NNGraph.reserve(16U); + NNContext.reserve(16U); +} + +PO::List WasiNNEnvironment::NNModels( + PO::Description( + "Allow preload models from wasinn plugin. Each NN model can be specified as --nn-preload `COMMAND`."sv), + PO::MetaVar("COMMANDS"sv)); + +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +PO::Option WasiNNEnvironment::NNRPCURI( + PO::Description("Specify NN RPC URI to connect (\"unix://...\")"sv), + PO::MetaVar("URI"sv), PO::DefaultValue(std::string(""))); +#endif + +namespace { +void addOptions(const Plugin::Plugin::PluginDescriptor *, + PO::ArgumentParser &Parser) noexcept { + Parser.add_option("nn-preload"sv, WasiNNEnvironment::NNModels); +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (getenv("_WASI_NN_RPCSERVER") == nullptr) { + // RPC client mode + Parser.add_option("nn-rpc-uri"sv, WasiNNEnvironment::NNRPCURI); + } +#endif +} + +static Plugin::PluginModule::ModuleDescriptor MD[] = { + { + /* Name */ "wasi_nn", + /* Description */ "", + /* Create */ create, + }, +}; + +Plugin::Plugin::PluginDescriptor Descriptor{ + /* Name */ "wasi_nn", + /* Description */ "", + /* APIVersion */ Plugin::Plugin::CurrentAPIVersion, + /* Version */ + {WASI_NN_VERSION_MAJOR, WASI_NN_VERSION_MINOR, WASI_NN_VERSION_PATCH, 0}, + /* ModuleCount */ 1, + /* ModuleDescriptions */ MD, + /* ComponentCount */ 0, + /* ComponentDescriptions */ nullptr, + /* AddOptions */ addOptions, +}; +} // namespace + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace WASINN + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h new file mode 100644 index 00000000..d9123069 --- /dev/null +++ b/plugins/wasi_nn/wasinnenv.h @@ -0,0 +1,394 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "GGML/core/ggml_core.h" +#include "wasinn_bitnet.h" +#include "wasinn_chattts.h" +#include "wasinn_mlx.h" +#include "wasinn_neuralspeed.h" +#include "wasinn_onnx.h" +#include "wasinn_openvino.h" +#include "wasinn_openvino_genai.h" +#include "wasinn_piper.h" +#include "wasinn_tf.h" +#include "wasinn_tfl.h" +#include "wasinn_torch.h" +#include "wasinn_whisper.h" +#include "wasinntypes.h" + +#include "host/wasi/environ.h" + +#include "common/spdlog.h" +#include "host/wasi/wasimodule.h" +#include "plugin/plugin.h" +#include "runtime/callingframe.h" + +#include +#include +#include +#include +#include + +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +#include +#include +#endif + +namespace WasmEdge { +namespace Host { +namespace WASINN { + +namespace detail { +template struct VariantIndex; + +template +struct VariantIndex> + : std::integral_constant {}; + +template +struct VariantIndex> + : std::integral_constant< + std::size_t, VariantIndex>::value + 1> {}; + +template +inline constexpr std::size_t VariantIndexV = VariantIndex::value; + +template struct BackendTrait; +#define EACH(B) \ + template <> struct BackendTrait { \ + using Graph = B::Graph; \ + using Context = B::Context; \ + }; +FOR_EACH_BACKEND(EACH) +#undef EACH + +template using BackendGraphT = typename BackendTrait::Graph; +template using BackendContextT = typename BackendTrait::Context; +} // namespace detail + +class Graph { +public: + Graph() = delete; + Graph(Backend BE) noexcept : Impl(std::in_place_type_t()) { + init(BE); + } + + Backend getBackend() const noexcept { + using V = std::decay_t; + switch (Impl.index()) { +#define EACH(B) \ + case detail::VariantIndexV: \ + return Backend::B; + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + __builtin_unreachable(); + } + } + + template auto &get() noexcept { + return *std::get_if>(&Impl); + } + template const auto &get() const noexcept { + return *std::get_if>(&Impl); + } + template auto &get() noexcept { return *std::get_if(&Impl); } + template const auto &get() const noexcept { + return *std::get_if(&Impl); + } + + void init(Backend BE) noexcept { + switch (BE) { +#define EACH(B) \ + case Backend::B: \ + Impl.emplace(); \ + break; + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + __builtin_unreachable(); + } + Stat = Status::Uninitialized; + CtxCnt = 0; + } + void reset() noexcept { + Impl = std::monostate{}; + Stat = Status::Uninitialized; + CtxCnt = 0; + } + void increaseContext() noexcept { CtxCnt++; } + void decreaseContext() noexcept { + assuming(CtxCnt > 0); + CtxCnt--; + } + uint32_t getContextCount() const noexcept { return CtxCnt; } + bool isFinalized() const noexcept { + return Stat == Status::Uninitialized || Stat == Status::Finalized; + } + bool isReady() const noexcept { return Stat == Status::Ready; } + void setInvalid() noexcept { Stat = Status::Invalid; } + void setFinalized() noexcept { Stat = Status::Finalized; } + void setReady() noexcept { Stat = Status::Ready; } + +private: + std::variant< +#define EACH(B) B::Graph, + FOR_EACH_BACKEND(EACH) +#undef EACH + std::monostate> + Impl; + // Graph status. + // Uninitialized: A new graph in monostate. + // Invalid: The graph failed to load in set_input with metadata. It can be + // reloaded with new metadata in set_input. + // Finalized: The graph is being deleted, but there are linked contexts. + // This graph ID will be released once the contexts are + // deleted. + // Ready: This graph can be used to create a context. + enum class Status : uint8_t { Uninitialized, Invalid, Finalized, Ready }; + Status Stat; + uint32_t CtxCnt; +}; + +class Context { +public: + Context() = delete; + Context(uint32_t GId, Graph &G) noexcept + : Impl(std::in_place_type_t()) { + init(GId, G); + } + + Backend getBackend() const noexcept { + using V = std::decay_t; + switch (Impl.index()) { +#define EACH(B) \ + case detail::VariantIndexV: \ + return Backend::B; + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + __builtin_unreachable(); + } + } + + template auto &get() noexcept { + return *std::get_if>(&Impl); + } + template const auto &get() const noexcept { + return *std::get_if>(&Impl); + } + template auto &get() noexcept { return *std::get_if(&Impl); } + template const auto &get() const noexcept { + return *std::get_if(&Impl); + } + + void init(uint32_t GId, Graph &G) noexcept { + switch (G.getBackend()) { +#define EACH(B) \ + case Backend::B: \ + Impl.emplace(GId, G.get()); \ + break; + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + __builtin_unreachable(); + } + Stat = Status::Uninitialized; + GraphId = GId; + } + void reset() noexcept { + Impl = std::monostate{}; + Stat = Status::Uninitialized; + GraphId = 0; + } + uint32_t getGraphId() const noexcept { + return static_cast(GraphId); + } + bool isReady() const noexcept { return Stat == Status::Ready; } + void setReady() noexcept { Stat = Status::Ready; } + +private: + std::variant< +#define EACH(B) B::Context, + FOR_EACH_BACKEND(EACH) +#undef EACH + std::monostate> + Impl; + // Context status. + // Uninitialized: A new context in monostate. + // Ready: This context can be used to infer. + enum class Status : uint8_t { Uninitialized, Ready }; + Status Stat; + uint32_t GraphId; +}; + +struct WasiNNEnvironment : +#define EACH(B) B::Environ, + FOR_EACH_BACKEND(EACH) +#undef EACH + std::monostate { + + using Callback = std::function( + WASINN::WasiNNEnvironment &, Span>, WASINN::Backend, + WASINN::Device, uint32_t &)>; + + WasiNNEnvironment() noexcept; + + bool mdGet(std::string Name, uint32_t &GraphId) noexcept { + std::shared_lock Lock(MdMutex); + if (auto It = MdMap.find(Name); It != MdMap.end()) { + GraphId = EndianValue(static_cast(It->second)).le(); + return true; + } + return false; + } + + void mdRemoveById(uint32_t GraphId) noexcept { + std::unique_lock Lock(MdMutex); + for (auto It = MdMap.begin(); It != MdMap.end();) { + if (It->second == static_cast(GraphId)) { + It = MdMap.erase(It); + } else { + ++It; + } + } + } + + Expect + mdBuild(std::string Name, uint32_t &GraphId, Callback Load, + std::vector Config = std::vector()) noexcept { + std::unique_lock Lock(MdMutex); + auto It = RawMdMap.find(Name); + if (It != RawMdMap.end()) { + auto RawMd = std::get<0>(It->second); + std::vector> Builders; + Builders.reserve(RawMd.size()); + for (auto &Builder : RawMd) { + Builders.emplace_back(Builder); + } + // Add config to the end of Builders if exists. + if (Config.size() > 0) { + Builders.emplace_back(Config); + } + auto Result = Load(*this, Builders, std::get<1>(It->second), + std::get<2>(It->second), GraphId); + if (Result.has_value()) { + MdMap[Name] = GraphId; + } + return Result; + } + return WASINN::ErrNo::NotFound; + } + + uint32_t newGraph(Backend BE) noexcept { + std::unique_lock Lock(GraphMutex); + uint32_t ID = static_cast(NNGraph.size()); + if (NNGraphRecycle.empty()) { + NNGraph.emplace_back(BE); + } else { + ID = *NNGraphRecycle.begin(); + NNGraph[ID].init(BE); + NNGraphRecycle.erase(ID); + } + return ID; + } + + uint32_t newContext(uint32_t GId, Graph &G) noexcept { + std::unique_lock Lock(GraphMutex); + assuming(NNGraph.size() > GId); + // TODO: Merge GId into graph class. + uint32_t ID = static_cast(NNContext.size()); + if (NNContextRecycle.empty()) { + NNContext.emplace_back(GId, G); + } else { + ID = *NNContextRecycle.begin(); + NNContext[ID].init(GId, G); + NNContextRecycle.erase(ID); + } + G.increaseContext(); + return ID; + } + + void deleteGraph(const uint32_t Id) noexcept { + // TODO: Add the deallocation callback. + std::unique_lock Lock(GraphMutex); + if (Id < NNGraph.size()) { + auto &G = NNGraph[Id]; + G.setFinalized(); + if (G.getContextCount() == 0) { + // All contexts are deleted. Release the graph ID. + if (Id == NNGraph.size() - 1) { + NNGraph.pop_back(); + } else { + G.reset(); + NNGraphRecycle.insert(Id); + } + } + } + } + + void deleteContext(const uint32_t Id) noexcept { + // TODO: Add the deallocation callback. + std::unique_lock Lock(GraphMutex); + if (Id < NNContext.size() && + NNContextRecycle.find(Id) == NNContextRecycle.end()) { + auto GId = NNContext[Id].getGraphId(); + auto &G = NNGraph[GId]; + G.decreaseContext(); + if (G.getContextCount() == 0 && G.isFinalized()) { + // All contexts are deleted. Release the graph ID. + if (GId == NNGraph.size() - 1) { + NNGraph.pop_back(); + } else { + G.reset(); + NNGraphRecycle.insert(GId); + } + } + if (Id == NNContext.size() - 1) { + NNContext.pop_back(); + } else { + NNContext[Id].reset(); + NNContextRecycle.insert(Id); + } + } + } + + // Md storage + mutable std::shared_mutex MdMutex; + std::unordered_map>, + Backend, Device>> + RawMdMap; + std::unordered_map MdMap; + + // Graph and context + mutable std::shared_mutex GraphMutex; + std::unordered_set NNGraphRecycle; + std::vector NNGraph; + std::unordered_set NNContextRecycle; + std::vector NNContext; + + // Preload model list + static PO::List NNModels; +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + static PO::Option NNRPCURI; // For RPC client mode + std::shared_ptr NNRPCChannel; +#endif + + const Host::WASI::Environ *getEnv() const noexcept { return Environ; } + void setEnviron(const Runtime::CallingFrame *CurrentFrame) noexcept { + auto *WasiModule = CurrentFrame->getWASIModule(); + if (WasiModule != nullptr) { + Environ = dynamic_cast(WasiModule) + ->getEnv(); + } + } + +private: + const Host::WASI::Environ *Environ = nullptr; +}; + +} // namespace WASINN +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp new file mode 100644 index 00000000..afa7b66a --- /dev/null +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -0,0 +1,752 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinnfunc.h" +#include "wasinnenv.h" + +#include "common/spdlog.h" + +#include +#include + +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +#include "wasi_ephemeral_nn.grpc.pb.h" + +#include +#endif // #ifdef WASMEDGE_BUILD_WASI_NN_RPC + +namespace WasmEdge { +namespace Host { + +namespace { +inline void reportUnknownBackend(WASINN::Backend B) noexcept { + spdlog::error("[WASI-NN] Unknown backend {}."sv, static_cast(B)); +} +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Backend Backend, WASINN::Device Device, + uint32_t &GraphId) { + switch (Backend) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::load(Env, Builders, Device, GraphId); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; + } +} +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +WASINN::ErrNo metadataToErrNo( + const std::multimap &Metadata) { + if (Metadata.find("errno") != Metadata.end()) { + auto ErrNo = std::stoi(Metadata.find("errno")->second.data()); + return static_cast(ErrNo); + } + return WASINN::ErrNo::Success; +} +#endif // #ifdef WASMEDGE_BUILD_WASI_NN_RPC +} // namespace + +Expect +WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, + uint32_t BuilderLen, uint32_t RawEncoding, uint32_t Target, + uint32_t GraphIdPtr) { + Env.setEnviron(&Frame); +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + // TODO: implement RPC for Load + spdlog::error("[WASI-NN] RPC client is not implemented for Load"sv); + return WASINN::ErrNo::UnsupportedOperation; + } +#endif + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + // Check the return value: GraphIdPtr should be valid. + uint32_t *GraphId = MemInst->getPointer(GraphIdPtr); + if (unlikely(GraphId == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the return GraphID memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + // Get and check the device. + const auto Device = static_cast(Target); + switch (Device) { + case WASINN::Device::CPU: + case WASINN::Device::GPU: + case WASINN::Device::TPU: + break; + default: + spdlog::error("[WASI-NN] Unknown device {}."sv, Target); + return WASINN::ErrNo::InvalidArgument; + } + spdlog::debug("[WASI-NN] Using device: {}."sv, Device); + + // Builders' Layout: + // | builder-0 | builder-0 len | builder-1 | builder-1 len | ... + struct WasiBuilderPair { + uint32_t Ptr; + uint32_t Len; + }; + + const auto WasiBuilders = + MemInst->getSpan(BuilderPtr, BuilderLen); + if (unlikely(WasiBuilders.size() != BuilderLen)) { + spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + std::vector> Builders; + Builders.reserve(BuilderLen); + for (size_t I = 0; I < WasiBuilders.size(); ++I) { + const auto &WasiBuilder = WasiBuilders[I]; + auto Builder = MemInst->getSpan(EndianValue(WasiBuilder.Ptr).le(), + EndianValue(WasiBuilder.Len).le()); + if (unlikely(Builder.size() != EndianValue(WasiBuilder.Len).le())) { + spdlog::error("[WASI-NN] Failed when accessing the Builder[{}] memory."sv, + I); + return WASINN::ErrNo::InvalidArgument; + } + Builders.emplace_back(Builder); + } + auto Backend = static_cast(RawEncoding); + return load(Env, Builders, Backend, Device, *GraphId); +} + +Expect +WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, + uint32_t NameLen, uint32_t GraphIdPtr) { + Env.setEnviron(&Frame); + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + // Check the return value: GraphIdPtr should be valid. + uint32_t *GraphId = MemInst->getPointer(GraphIdPtr); + if (unlikely(GraphId == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the return GraphID memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the model name. + auto Name = MemInst->getPointer(NamePtr); + if (unlikely(Name == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the return Name memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::Graph::NewStub(Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::LoadByNameRequest Req; + auto NameStrView = MemInst->getStringView(NamePtr, NameLen); + Req.set_name(NameStrView.data(), NameStrView.size()); + wasi_ephemeral_nn::LoadByNameResult Res; + auto Status = Stub->LoadByName(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + *GraphId = Res.graph_handle(); + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + + // Get the model. + std::string ModelName(reinterpret_cast(Name), NameLen); + if (Env.mdGet(ModelName, *GraphId)) { + return WASINN::ErrNo::Success; + } else { + return Env.mdBuild(ModelName, *GraphId, load); + } +} + +Expect WasiNNLoadByNameWithConfig::bodyImpl( + const Runtime::CallingFrame &Frame, uint32_t NamePtr, uint32_t NameLen, + uint32_t ConfigPtr, uint32_t ConfigLen, uint32_t GraphIdPtr) { + Env.setEnviron(&Frame); + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + // Check the return value: GraphIdPtr should be valid. + auto GraphId = MemInst->getPointer(GraphIdPtr); + if (unlikely(GraphId == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the return GraphID memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the model name. + auto Name = MemInst->getPointer(NamePtr); + if (unlikely(Name == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the return Name memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the model config. + auto Config = MemInst->getPointer(ConfigPtr); + if (unlikely(Config == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the return Config memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::Graph::NewStub(Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::LoadByNameWithConfigRequest Req; + auto NameStrView = MemInst->getStringView(NamePtr, NameLen); + auto ConfigStrView = MemInst->getStringView(ConfigPtr, ConfigLen); + Req.set_name(NameStrView.data(), NameStrView.size()); + Req.set_config(ConfigStrView.data(), ConfigStrView.size()); + wasi_ephemeral_nn::LoadByNameWithConfigResult Res; + auto Status = Stub->LoadByNameWithConfig(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + *GraphId = Res.graph_handle(); + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + + // Get the model. + std::string ModelName(reinterpret_cast(Name), NameLen); + std::vector ModelConfig(reinterpret_cast(Config), + reinterpret_cast(Config) + + ConfigLen); + if (Env.mdGet(ModelName, *GraphId)) { + return WASINN::ErrNo::Success; + } else { + return Env.mdBuild(ModelName, *GraphId, load, ModelConfig); + } +} + +Expect +WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t GraphId, uint32_t ContextPtr) { + Env.setEnviron(&Frame); + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + // Check the return value: Context should be valid. + uint32_t *Context = MemInst->getPointer(ContextPtr); + if (unlikely(Context == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Context memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphResource::NewStub(Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::InitExecutionContextRequest Req; + Req.set_resource_handle(GraphId); + wasi_ephemeral_nn::InitExecutionContextResult Res; + auto Status = Stub->InitExecutionContext(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + *Context = Res.ctx_handle(); + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + + if (Env.NNGraph.size() <= GraphId || Env.NNGraph[GraphId].isFinalized()) { + spdlog::error("[WASI-NN] init_execution_context: Graph ID {} does not " + "exist or is unloaded."sv, + GraphId); + return WASINN::ErrNo::InvalidArgument; + } + if (!Env.NNGraph[GraphId].isReady()) { + spdlog::error("[WASI-NN] init_execution_context: Graph ID {} is invalid. " + "Please reload or unload this graph."sv, + GraphId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (const auto Backend = Env.NNGraph[GraphId].getBackend()) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::initExecCtx(Env, GraphId, *Context); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; + } +} + +Expect +WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ContextId, + uint32_t Index, uint32_t TensorPtr) { + Env.setEnviron(&Frame); + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + // Tensor's Layout: + // | dim buf | dim buf len | rtype | data buf | data buf len | + struct WasiTensorData { + uint32_t DimensionPtr; + uint32_t DimensionLen; + uint32_t RType; + uint32_t TensorPtr; + uint32_t TensorLen; + }; + // Get the tensor. + auto *WasiTensor = MemInst->getPointer(TensorPtr); + if (unlikely(WasiTensor == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + WASINN::TensorData Tensor; + Tensor.Dimension = + MemInst->getSpan(EndianValue(WasiTensor->DimensionPtr).le(), + EndianValue(WasiTensor->DimensionLen).le()); + if (unlikely(Tensor.Dimension.size() != + EndianValue(WasiTensor->DimensionLen).le())) { + spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + Tensor.Tensor = + MemInst->getSpan(EndianValue(WasiTensor->TensorPtr).le(), + EndianValue(WasiTensor->TensorLen).le()); + if (unlikely(Tensor.Tensor.size() != + EndianValue(WasiTensor->TensorLen).le())) { + spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + switch (const auto RType = static_cast( + EndianValue(WasiTensor->RType).le())) { + case WASINN::TensorType::F16: + case WASINN::TensorType::F32: + case WASINN::TensorType::U8: + case WASINN::TensorType::I32: + Tensor.RType = RType; + break; + default: + spdlog::error("[WASI-NN] Unknown tensor type {}."sv, + static_cast(RType)); + return WASINN::ErrNo::InvalidArgument; + } + +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::SetInputRequest Req; + Req.set_resource_handle(ContextId); + Req.set_index(Index); + wasi_ephemeral_nn::Tensor RPCTensor; + RPCTensor.mutable_dimensions()->Add(Tensor.Dimension.begin(), + Tensor.Dimension.end()); + RPCTensor.set_ty(wasi_ephemeral_nn::TensorType(Tensor.RType)); + RPCTensor.set_data(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + *Req.mutable_tensor() = RPCTensor; + google::protobuf::Empty Res; + auto Status = Stub->SetInput(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error("[WASI-NN] set_input: Context ID {} does not exist."sv, + ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (const auto Backend = Env.NNContext[ContextId].getBackend()) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::setInput(Env, ContextId, Index, Tensor); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; + } +} + +Expect +WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ContextId, uint32_t Index, + uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr) { + Env.setEnviron(&Frame); + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto OutBuffer = + MemInst->getSpan(OutBufferPtr, OutBufferMaxSize); + if (unlikely(OutBuffer.data() == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the Output Buffer memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + uint32_t *BytesWritten = MemInst->getPointer(BytesWrittenPtr); + if (unlikely(BytesWritten == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::GetOutputRequest Req; + Req.set_resource_handle(ContextId); + Req.set_index(Index); + wasi_ephemeral_nn::GetOutputResult Res; + auto Status = Stub->GetOutput(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + uint32_t BytesWrittenVal = + std::min(static_cast(Res.data().size()), OutBufferMaxSize); + std::copy_n(Res.data().begin(), BytesWrittenVal, OutBuffer.begin()); + *BytesWritten = BytesWrittenVal; + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error("[WASI-NN] get_output: Context ID {} does not exist."sv, + ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (const auto Backend = Env.NNContext[ContextId].getBackend()) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::getOutput(Env, ContextId, Index, OutBuffer, \ + *BytesWritten); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; + } +} + +Expect WasiNNGetOutputSingle::bodyImpl( + const Runtime::CallingFrame &Frame, uint32_t ContextId, uint32_t Index, + uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr) { + Env.setEnviron(&Frame); + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto OutBuffer = + MemInst->getSpan(OutBufferPtr, OutBufferMaxSize); + if (unlikely(OutBuffer.data() == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the Output Buffer memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + uint32_t *BytesWritten = MemInst->getPointer(BytesWrittenPtr); + if (unlikely(BytesWritten == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::GetOutputRequest Req; + Req.set_resource_handle(ContextId); + Req.set_index(Index); + wasi_ephemeral_nn::GetOutputResult Res; + auto Status = Stub->GetOutputSingle(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + uint32_t BytesWrittenVal = + std::min(static_cast(Res.data().size()), OutBufferMaxSize); + std::copy_n(Res.data().begin(), BytesWrittenVal, OutBuffer.begin()); + *BytesWritten = BytesWrittenVal; + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error( + "[WASI-NN] get_output_single: Context ID {} does not exist."sv, + ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (Env.NNContext[ContextId].getBackend()) { + case WASINN::Backend::GGML: + return WASINN::GGML::getOutputSingle(Env, ContextId, Index, OutBuffer, + *BytesWritten); + case WASINN::Backend::BitNet: + return WASINN::BitNet::getOutputSingle(Env, ContextId, Index, OutBuffer, + *BytesWritten); + default: + spdlog::error( + "[WASI-NN] get_output_single: Only GGML and BitNet backend supports "sv + "get_output_single."sv); + return WASINN::ErrNo::InvalidArgument; + } +} + +Expect +WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ContextId) { + Env.setEnviron(&Frame); +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::ComputeRequest Req; + Req.set_resource_handle(ContextId); + google::protobuf::Empty Res; + auto Status = Stub->Compute(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error("[WASI-NN] compute: Context ID {} does not exist."sv, + ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + auto GraphId = Env.NNContext[ContextId].getGraphId(); + assuming(Env.NNGraph.size() > GraphId); + if (!Env.NNGraph[GraphId].isReady()) { + spdlog::error("[WASI-NN] compute: Graph ID {} for context ID {} does not "sv + "exist or has released."sv, + GraphId, ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (const auto Backend = Env.NNContext[ContextId].getBackend()) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::compute(Env, ContextId); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; + } +} + +Expect +WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ContextId) { + Env.setEnviron(&Frame); +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::ComputeRequest Req; + Req.set_resource_handle(ContextId); + google::protobuf::Empty Res; + auto Status = Stub->ComputeSingle(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error("[WASI-NN] compute_single: Context ID {} does not exist."sv, + ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + auto GraphId = Env.NNContext[ContextId].getGraphId(); + assuming(Env.NNGraph.size() > GraphId); + if (!Env.NNGraph[GraphId].isReady()) { + spdlog::error("[WASI-NN] compute_single: Graph ID {} for context ID {} "sv + "does not exist or has released."sv, + GraphId, ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (Env.NNContext[ContextId].getBackend()) { + case WASINN::Backend::GGML: + return WASINN::GGML::computeSingle(Env, ContextId); + case WASINN::Backend::BitNet: + return WASINN::BitNet::computeSingle(Env, ContextId); + default: + spdlog::error( + "[WASI-NN] compute_single: Only GGML and BitNet backend supports "sv + "compute_single."sv); + return WASINN::ErrNo::InvalidArgument; + } +} + +Expect +WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ContextId) { + Env.setEnviron(&Frame); +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::FiniSingleRequest Req; + Req.set_resource_handle(ContextId); + google::protobuf::Empty Res; + auto Status = Stub->FiniSingle(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error("[WASI-NN] fini_single: Context ID {} does not exist."sv, + ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (Env.NNContext[ContextId].getBackend()) { + case WASINN::Backend::GGML: + return WASINN::GGML::finiSingle(Env, ContextId); + case WASINN::Backend::BitNet: + return WASINN::BitNet::finiSingle(Env, ContextId); + default: + spdlog::error( + "[WASI-NN] fini_single: Only GGML and BitNet backend supports fini_single."sv); + return WASINN::ErrNo::InvalidArgument; + } +} + +Expect WasiNNUnload::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t GraphId) { + Env.setEnviron(&Frame); +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + // TODO: implement RPC for unload + spdlog::error("[WASI-NN] RPC client is not implemented for unload"sv); + return WASINN::ErrNo::UnsupportedOperation; + } +#endif + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + if (Env.NNGraph.size() <= GraphId) { + spdlog::error("[WASI-NN] unload: GraphId {} does not exist."sv, GraphId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (Env.NNGraph[GraphId].getBackend()) { + case WASINN::Backend::GGML: + return WASINN::GGML::unload(Env, GraphId); + case WASINN::Backend::Whisper: + return WASINN::Whisper::unload(Env, GraphId); + case WASINN::Backend::ChatTTS: + return WASINN::ChatTTS::unload(Env, GraphId); + case WASINN::Backend::BitNet: + return WASINN::BitNet::unload(Env, GraphId); + default: + spdlog::error("[WASI-NN] unload: Only GGML, Whisper, ChatTTS and BitNet "sv + "backends support unload."sv); + return WASINN::ErrNo::InvalidArgument; + } +} + +Expect +WasiNNFinalizeExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ContextId) { + Env.setEnviron(&Frame); +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + // TODO: implement RPC for finalize_execution_context + spdlog::error("[WASI-NN] RPC client is not implemented for "sv + "finalize_execution_context"sv); + return WASINN::ErrNo::UnsupportedOperation; + } +#endif + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + if (Env.NNContext.size() <= ContextId) { + spdlog::error( + "[WASI-NN] finalize_execution_context: Context ID {} does not exist."sv, + ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (Env.NNContext[ContextId].getBackend()) { + case WASINN::Backend::GGML: + return WASINN::GGML::finalizeExecCtx(Env, ContextId); + case WASINN::Backend::Whisper: + return WASINN::Whisper::finalizeExecCtx(Env, ContextId); + case WASINN::Backend::BitNet: + return WASINN::BitNet::finalizeExecCtx(Env, ContextId); + default: + spdlog::error( + "[WASI-NN] finalize_execution_context: Only GGML, BitNet and "sv + "Whisper backends support finalize_execution_context."sv); + return WASINN::ErrNo::InvalidArgument; + } +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h new file mode 100644 index 00000000..769955dc --- /dev/null +++ b/plugins/wasi_nn/wasinnfunc.h @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinnbase.h" + +#include "runtime/callingframe.h" + +#include + +namespace WasmEdge { +namespace Host { + +class WasiNNLoad : public WasiNN { +public: + WasiNNLoad(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, + uint32_t BuilderLen, uint32_t Encoding, uint32_t Target, + uint32_t GraphIdPtr) { + return bodyImpl(Frame, BuilderPtr, BuilderLen, Encoding, Target, GraphIdPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &, + uint32_t BuilderPtr, uint32_t BuilderLen, + uint32_t Encoding, uint32_t Target, + uint32_t GraphIdPtr); +}; + +class WasiNNLoadByName : public WasiNN { +public: + WasiNNLoadByName(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t NamePtr, + uint32_t NameLen, uint32_t GraphIdPtr) { + return bodyImpl(Frame, NamePtr, NameLen, GraphIdPtr).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &, + uint32_t NamePtr, uint32_t NameLen, + uint32_t GraphIdPtr); +}; + +class WasiNNLoadByNameWithConfig : public WasiNN { +public: + WasiNNLoadByNameWithConfig(WASINN::WasiNNEnvironment &HostEnv) + : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t NamePtr, + uint32_t NameLen, uint32_t ConfigPtr, + uint32_t ConfigLen, uint32_t GraphIdPtr) { + return bodyImpl(Frame, NamePtr, NameLen, ConfigPtr, ConfigLen, GraphIdPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &, + uint32_t NamePtr, uint32_t NameLen, + uint32_t ConfigPtr, uint32_t ConfigLen, + uint32_t GraphIdPtr); +}; + +class WasiNNInitExecCtx : public WasiNN { +public: + WasiNNInitExecCtx(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GraphId, + uint32_t ContextPtr) { + return bodyImpl(Frame, GraphId, ContextPtr).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t GraphId, uint32_t ContextPtr); +}; + +class WasiNNSetInput : public WasiNN { +public: + WasiNNSetInput(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context, + uint32_t Index, uint32_t TensorPtr) { + return bodyImpl(Frame, Context, Index, TensorPtr).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context, uint32_t Index, + uint32_t TensorPtr); +}; + +class WasiNNGetOutput : public WasiNN { +public: + WasiNNGetOutput(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context, + uint32_t Index, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { + return bodyImpl(Frame, Context, Index, OutBufferPtr, OutBufferMaxSize, + BytesWrittenPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context, uint32_t Index, + uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr); +}; + +class WasiNNGetOutputSingle : public WasiNN { +public: + WasiNNGetOutputSingle(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context, + uint32_t Index, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { + return bodyImpl(Frame, Context, Index, OutBufferPtr, OutBufferMaxSize, + BytesWrittenPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context, uint32_t Index, + uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr); +}; + +class WasiNNCompute : public WasiNN { +public: + WasiNNCompute(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context) { + return bodyImpl(Frame, Context).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context); +}; + +class WasiNNComputeSingle : public WasiNN { +public: + WasiNNComputeSingle(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context) { + return bodyImpl(Frame, Context).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context); +}; + +class WasiNNFiniSingle : public WasiNN { +public: + WasiNNFiniSingle(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context) { + return bodyImpl(Frame, Context).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context); +}; + +class WasiNNUnload : public WasiNN { +public: + WasiNNUnload(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GraphId) { + return bodyImpl(Frame, GraphId).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t GraphId); +}; + +class WasiNNFinalizeExecCtx : public WasiNN { +public: + WasiNNFinalizeExecCtx(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context) { + return bodyImpl(Frame, Context).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnmodule.cpp b/plugins/wasi_nn/wasinnmodule.cpp new file mode 100644 index 00000000..33c3d5b1 --- /dev/null +++ b/plugins/wasi_nn/wasinnmodule.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinnmodule.h" +#include "wasinnfunc.h" + +namespace WasmEdge { +namespace Host { + +WasiNNModule::WasiNNModule() : ModuleInstance("wasi_ephemeral_nn") { + addHostFunc("load", std::make_unique(Env)); + addHostFunc("load_by_name", std::make_unique(Env)); + addHostFunc("load_by_name_with_config", + std::make_unique(Env)); + addHostFunc("init_execution_context", + std::make_unique(Env)); + addHostFunc("set_input", std::make_unique(Env)); + addHostFunc("get_output", std::make_unique(Env)); + addHostFunc("get_output_single", + std::make_unique(Env)); + addHostFunc("compute", std::make_unique(Env)); + addHostFunc("compute_single", std::make_unique(Env)); + addHostFunc("fini_single", std::make_unique(Env)); + addHostFunc("unload", std::make_unique(Env)); + addHostFunc("finalize_execution_context", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnmodule.h b/plugins/wasi_nn/wasinnmodule.h new file mode 100644 index 00000000..3a428eec --- /dev/null +++ b/plugins/wasi_nn/wasinnmodule.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinnenv.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiNNModule : public Runtime::Instance::ModuleInstance { +public: + WasiNNModule(); + + WASINN::WasiNNEnvironment &getEnv() { return Env; } + +private: + WASINN::WasiNNEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinntypes.h b/plugins/wasi_nn/wasinntypes.h new file mode 100644 index 00000000..d82873d3 --- /dev/null +++ b/plugins/wasi_nn/wasinntypes.h @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "common/span.h" +#include "common/spdlog.h" + +#include + +namespace WasmEdge::Host::WASINN { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. + UnsupportedOperation = 6, // Unsupported Operation. + TooLarge = 7, // Too Large. + NotFound = 8, // Not Found. + EndOfSequence = 100, // End of Sequence Found. + ContextFull = 101, // Context Full. + PromptTooLong = 102, // Prompt Too Long. + ModelNotFound = 103, // Model Not Found. +}; + +enum class TensorType : uint8_t { + F16 = 0, + F32 = 1, + F64 = 2, + U8 = 3, + I32 = 4, + I64 = 5 +}; + +enum class Device : uint32_t { CPU = 0, GPU = 1, TPU = 2, AUTO = 3 }; + +enum class Backend : uint8_t { + OpenVINO = 0, + ONNX = 1, + Tensorflow = 2, + PyTorch = 3, + TensorflowLite = 4, + Autodetect = 5, + GGML = 6, + NeuralSpeed = 7, + Whisper = 9, + MLX = 10, + Piper = 11, + ChatTTS = 12, + OpenVINOGenAI = 13, + BitNet = 14, +}; + +#define FOR_EACH_BACKEND(F) \ + F(OpenVINO) \ + F(ONNX) \ + F(Tensorflow) \ + F(PyTorch) \ + F(TensorflowLite) \ + F(GGML) \ + F(NeuralSpeed) \ + F(Whisper) \ + F(Piper) \ + F(ChatTTS) \ + F(MLX) \ + F(OpenVINOGenAI) \ + F(BitNet) + +struct TensorData { + Span Dimension; + WASINN::TensorType RType; + Span Tensor; +}; + +} // namespace WasmEdge::Host::WASINN + +template <> +struct fmt::formatter + : fmt::formatter { + fmt::format_context::iterator format(WasmEdge::Host::WASINN::TensorType RType, + fmt::format_context &Ctx) const { + using namespace std::literals; + std::string_view Name; + switch (RType) { + case WasmEdge::Host::WASINN::TensorType::F16: + Name = "F16"sv; + break; + case WasmEdge::Host::WASINN::TensorType::F32: + Name = "F32"sv; + break; + case WasmEdge::Host::WASINN::TensorType::F64: + Name = "F64"sv; + break; + case WasmEdge::Host::WASINN::TensorType::U8: + Name = "U8"sv; + break; + case WasmEdge::Host::WASINN::TensorType::I32: + Name = "I32"sv; + break; + case WasmEdge::Host::WASINN::TensorType::I64: + Name = "I64"sv; + break; + default: + Name = "Unknown"sv; + } + return fmt::formatter::format(Name, Ctx); + } +}; + +template <> +struct fmt::formatter + : fmt::formatter { + fmt::format_context::iterator format(WasmEdge::Host::WASINN::Device Target, + fmt::format_context &Ctx) const { + using namespace std::literals; + std::string_view Name; + switch (Target) { + case WasmEdge::Host::WASINN::Device::CPU: + Name = "CPU"sv; + break; + case WasmEdge::Host::WASINN::Device::GPU: + Name = "GPU"sv; + break; + case WasmEdge::Host::WASINN::Device::TPU: + Name = "TPU"sv; + break; + default: + Name = "Unknown"sv; + } + return fmt::formatter::format(Name, Ctx); + } +}; diff --git a/plugins/wasi_poll/CMakeLists.txt b/plugins/wasi_poll/CMakeLists.txt new file mode 100644 index 00000000..9c641135 --- /dev/null +++ b/plugins/wasi_poll/CMakeLists.txt @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_library(wasmedgePluginWasiPoll + SHARED + env.cpp + func.cpp + module.cpp +) + +target_compile_options(wasmedgePluginWasiPoll + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasiPoll + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/thirdparty +) + +target_link_libraries(wasmedgePluginWasiPoll + PUBLIC +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasiPoll + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiPoll + PRIVATE + wasmedge_shared + ) +endif() + +install( + TARGETS wasmedgePluginWasiPoll + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasi_poll/README.md b/plugins/wasi_poll/README.md new file mode 100644 index 00000000..4eae44f4 --- /dev/null +++ b/plugins/wasi_poll/README.md @@ -0,0 +1,3 @@ +# wasi_poll + +This is corresponding to [wasi-poll preview2](https://github.com/WebAssembly/wasi-poll). diff --git a/plugins/wasi_poll/base.h b/plugins/wasi_poll/base.h new file mode 100644 index 00000000..0a8f86fb --- /dev/null +++ b/plugins/wasi_poll/base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template +class WasiPoll : public Runtime::Component::HostFunction { +public: + WasiPoll(WasiPollEnvironment &HostEnv) + : Runtime::Component::HostFunction(), Env(HostEnv) {} + +protected: + WasiPollEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/env.cpp b/plugins/wasi_poll/env.cpp new file mode 100644 index 00000000..223839ad --- /dev/null +++ b/plugins/wasi_poll/env.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "env.h" +#include "module.h" + +namespace WasmEdge { +namespace Host { + +bool WasiPollEnvironment::isPollable(Pollable P) noexcept { + return PollableMap.at(P); +} +void WasiPollEnvironment::dropPollable(Pollable P) { PollableMap.erase(P); } +namespace { + +Runtime::Instance::ComponentInstance * +create(const Plugin::PluginComponent::ComponentDescriptor *) noexcept { + return new WasiPollModule(); +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasi_poll", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 1, 0, 0}, + .ModuleCount = 0, + .ModuleDescriptions = {}, + .ComponentCount = 1, + .ComponentDescriptions = + (Plugin::PluginComponent::ComponentDescriptor[]){ + { + .Name = "wasi:poll/poll", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/env.h b/plugins/wasi_poll/env.h new file mode 100644 index 00000000..4852f7b0 --- /dev/null +++ b/plugins/wasi_poll/env.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC +#pragma once + +#include "plugin/plugin.h" + +#include +#include + +namespace WasmEdge { +namespace Host { + +using Pollable = uint32_t; + +class WasiPollEnvironment { +public: + bool isPollable(Pollable P) noexcept; + void dropPollable(Pollable P); + +private: + std::unordered_map PollableMap; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/func.cpp b/plugins/wasi_poll/func.cpp new file mode 100644 index 00000000..8c622448 --- /dev/null +++ b/plugins/wasi_poll/func.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func.h" +#include "common/defines.h" +#include "common/errcode.h" + +namespace WasmEdge { +namespace Host { + +Expect Drop::body(Pollable P) { + Env.dropPollable(P); + return {}; +} + +Expect> PollOneoff::body(List In) { + std::vector Res; + for (auto P : In.collection()) { + Res.push_back(Env.isPollable(P)); + } + return List(std::move(Res)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/func.h b/plugins/wasi_poll/func.h new file mode 100644 index 00000000..0c5c4949 --- /dev/null +++ b/plugins/wasi_poll/func.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { + +class Drop : public WasiPoll { +public: + Drop(WasiPollEnvironment &HostEnv) : WasiPoll(HostEnv) {} + Expect body(Pollable This); +}; + +class PollOneoff : public WasiPoll { +public: + PollOneoff(WasiPollEnvironment &HostEnv) : WasiPoll(HostEnv) {} + Expect> body(List In); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/module.cpp b/plugins/wasi_poll/module.cpp new file mode 100644 index 00000000..b523f867 --- /dev/null +++ b/plugins/wasi_poll/module.cpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiPollModule::WasiPollModule() : ComponentInstance("wasi:poll/poll") { + addHostFunc("drop-pollable", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/module.h b/plugins/wasi_poll/module.h new file mode 100644 index 00000000..5c26930f --- /dev/null +++ b/plugins/wasi_poll/module.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiPollModule : public Runtime::Instance::ComponentInstance { +public: + WasiPollModule(); + + WasiPollEnvironment &getEnv() { return Env; } + +private: + WasiPollEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/CMakeLists.txt b/plugins/wasm_bpf/CMakeLists.txt new file mode 100644 index 00000000..17ab9a80 --- /dev/null +++ b/plugins/wasm_bpf/CMakeLists.txt @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +# Try to get libbpf use the following order +# - PkgConfig +# - ${LIBBPF_SOURCE_DIR} +# - FetchContent + +option(WASMEDGE_PLUGIN_WASM_BPF_BUILD_LIBBPF_WITH_PKG_CONF "Configure libbpf to use pkg-config for the build process. If enabled, the libbpf build script will utilize pkg-config to search for dependencies such as libz and libelf. If this feature is disabled, the headers and binaries for libz and libelf need to be correctly positioned." YES) + +message(STATUS "Trying to get libbpf..") +message(STATUS "Build libbpf with pkg-config: ${WASMEDGE_PLUGIN_WASM_BPF_BUILD_LIBBPF_WITH_PKG_CONF}") +set(LIBBPF_FOUND FALSE) + +# A wrapper function to add libbpf located at a local path as a dependency +function(AddLibbpfAsExternal SOURCE_ROOT WITH_PKG_CONF) + include(ExternalProject) + set(LIBBPF_SO_PATH ${SOURCE_ROOT}/src/build/libbpf.so) + set(LIBBPF_INCLUDE_DIRS_LOCAL "${SOURCE_ROOT}/src/root/usr/include" "${SOURCE_ROOT}/include/uapi" "${SOURCE_ROOT}/include") + set(LIBBPF_INCLUDE_DIRS ${LIBBPF_INCLUDE_DIRS_LOCAL} PARENT_SCOPE) + + set(LIBBPF_LIBRARIES ${LIBBPF_SO_PATH} PARENT_SCOPE) + set(LIBBPF_LIBRARIES_STATIC ${SOURCE_ROOT}/src/build/libbpf.a PARENT_SCOPE) + + if(${WITH_PKG_CONF}) + set(PKGCONF_PREFIX "") + else() + set(PKGCONF_PREFIX "NO_PKG_CONFIG=1") + set(LIBBPF_DEP_LIBRARIES "elf" "z" PARENT_SCOPE) + endif() + message(STATUS "SOURCE_ROOT=${SOURCE_ROOT}") + ExternalProject_Add(libbpf + PREFIX libbpf + SOURCE_DIR ${SOURCE_ROOT} + CONFIGURE_COMMAND "mkdir" "build" "root" + BUILD_COMMAND "${PKGCONF_PREFIX}" "OBJDIR=${SOURCE_ROOT}/src/build" "DESTDIR=${SOURCE_ROOT}/src/root" "CFLAGS=-fPIC" "make" "-C" "${SOURCE_ROOT}/src" "install" + INSTALL_COMMAND "cp" "${LIBBPF_SO_PATH}" "${CMAKE_CURRENT_BINARY_DIR}/libbpf.so" + BUILD_IN_SOURCE TRUE + BUILD_BYPRODUCTS ${LIBBPF_SO_PATH} ${SOURCE_ROOT}/src/build/libbpf.a + ) + + set(LIBBPF_TARGET_NAME libbpf PARENT_SCOPE) +endfunction() + +# Try PkgConfig +if(NOT ${LIBBPF_FOUND}) + find_package(PkgConfig) + + if(PkgConfig_FOUND) + message(STATUS "Try to get libbpf through PkgConfig") + + # It will set LIBBPF_FOUND for us + pkg_check_modules(LIBBPF libbpf>=1.2.0 IMPORTED_TARGET) + set(LIBBPF_TARGET_NAME "PkgConfig::LIBBPF") + message(STATUS "LIBBPF_FOUND=${LIBBPF_FOUND}") + + if(${LIBBPF_FOUND}) + SET(LIBBPF_FOUND TRUE) + else() + SET(LIBBPF_FOUND FALSE) + endif() + + if(${LIBBPF_FOUND}) + message(STATUS "libbpf found using PkgConfig") + set(LIBBPF_SOURCE "pkgconf") + else() + message(STATUS "libbpf not found using pkgconfig") + endif() + else() + message(STATUS "PkgConfig not found") + endif() +endif() + +# Try LIBBPF_SOURCE_DIR +if(NOT ${LIBBPF_FOUND}) + message(STATUS "Try to get libbpf through the pre-defined LIBBPF_SOURCE_DIR") + + if(DEFINED LIBBPF_SOURCE_DIR) + AddLibbpfAsExternal(${LIBBPF_SOURCE_DIR} ${WASMEDGE_PLUGIN_WASM_BPF_BUILD_LIBBPF_WITH_PKG_CONF}) + set(LIBBPF_FOUND TRUE) + message(STATUS "libbpf found using LIBBPF_SOURCE_DIR") + set(LIBBPF_SOURCE "sourcedir") + else() + message(STATUS "LIBBPF_SOURCE_DIR not defined") + endif() +endif() + +# Try FetchContent +if(NOT ${LIBBPF_FOUND}) + message(STATUS "Downloading libbpf source") + include(FetchContent) + FetchContent_Declare( + libbpf + GIT_REPOSITORY https://github.com/libbpf/libbpf + GIT_TAG 950cffc0366981d4e41b08f007b37bd6af931f25 + ) + FetchContent_MakeAvailable(libbpf) + message(STATUS "Downloading libbpf source - done") + + set(LIBBPF_DOWNLOAD_SOURCE_DIR "${libbpf_SOURCE_DIR}") + message(DEBUG "libbpf saved at: ${LIBBPF_DOWNLOAD_SOURCE_DIR}") + AddLibbpfAsExternal(${LIBBPF_DOWNLOAD_SOURCE_DIR} ${WASMEDGE_PLUGIN_WASM_BPF_BUILD_LIBBPF_WITH_PKG_CONF}) + set(LIBBPF_FOUND TRUE) + set(LIBBPF_SOURCE "fetch-content") +endif() + +# If we cannot find libbpf. +if(NOT ${LIBBPF_FOUND}) + message(FATAL_ERROR "Could not find libbpf") +endif() + +if(${WASMEDGE_PLUGIN_WASM_BPF_BUILD_LIBBPF_WITH_PKG_CONF}) + # Find the dependencies `libelf` and `libz` of libbpf + find_package(PkgConfig) + + pkg_check_modules(LIBBPF_DEP REQUIRED libelf zlib) + + message(STATUS "(From PKGCONF) LIBBPF_DEP_LIBRARIES=${LIBBPF_DEP_LIBRARIES}") +endif() + +message(STATUS "LIBBPF_INCLUDE_DIRS=${LIBBPF_INCLUDE_DIRS}") +message(STATUS "LIBBPF_LIBRARIES=${LIBBPF_LIBRARIES}") +message(STATUS "LIBBPF_TARGET_NAME=${LIBBPF_TARGET_NAME}") +message(STATUS "LIBBPF_LIBRARIES_STATIC=${LIBBPF_LIBRARIES_STATIC}") +message(STATUS "LIBBPF_SOURCE=${LIBBPF_SOURCE}") +message(STATUS "LIBBPF_DEP_LIBRARIES=${LIBBPF_DEP_LIBRARIES}") +message(STATUS "LIBBPF_DEP_LIBRARIES_STATIC=${LIBBPF_DEP_LIBRARIES_STATIC}") + +wasmedge_add_library(wasmedgePluginWasmBpf + SHARED + wasm-bpf-module.cpp + func-load-bpf-object.cpp + func-close-bpf-object.cpp + func-attach-bpf-program.cpp + func-bpf-buffer-poll.cpp + func-bpf-map-fd-by-name.cpp + func-bpf-map-operate.cpp + wasm-bpf.cpp + util.cpp +) + +add_dependencies(wasmedgePluginWasmBpf ${LIBBPF_TARGET_NAME}) + +if("${LIBBPF_SOURCE}" STREQUAL "pkgconf") + message(STATUS "Link libbpf dynamically") + target_link_libraries(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_LIBRARIES} ${LIBBPF_DEP_LIBRARIES}) +else() + # Link libbpf statically if we don't use pkgconf, because in this case libbpf is + # not installed systemwide. + message(STATUS "Link libbpf statically") + target_link_libraries(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_LIBRARIES_STATIC} ${LIBBPF_DEP_LIBRARIES}) +endif() + +target_include_directories(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_INCLUDE_DIRS}) + + +set_target_properties(wasmedgePluginWasmBpf PROPERTIES + CXX_STANDARD 17 + # Allow tests accessing plugin class functions + CXX_VISIBILITY_PRESET default + VISIBILITY_INLINES_HIDDEN OFF +) +# Fix undefined reference issue of `fmt::v9::formatter::format(WasmEdge::ErrInfo::InfoBoundary const&, fmt::v9::basic_format_context&) const` +target_link_libraries(wasmedgePluginWasmBpf + PUBLIC wasmedgeCommon +) + +target_compile_options(wasmedgePluginWasmBpf + PUBLIC + -DWASMEDGE_PLUGIN + -fPIC +) + +target_include_directories(wasmedgePluginWasmBpf + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} + ${LIBBPF_INCLUDE_DIRS} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmBpf + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmBpf + PRIVATE + wasmedge_shared + ) +endif() diff --git a/plugins/wasm_bpf/README.md b/plugins/wasm_bpf/README.md new file mode 100644 index 00000000..95fbcdc5 --- /dev/null +++ b/plugins/wasm_bpf/README.md @@ -0,0 +1,111 @@ +# wasm_bpf Plugin + +This plugin adds six host functions that give Wasm applications access to eBPF. + +The six functions are listed here. All of them are in the `wasm_bpf` module when +this plugin is loaded. + +```c +/// look up a BPF map fd by name. +i32 wasm_bpf_map_fd_by_name(u64 obj, u32 name); +/// detach and close a BPF program. +i32 wasm_close_bpf_object(u64 obj); +/// CO-RE load a BPF object into the kernel. +u64 wasm_load_bpf_object(u32 obj_buf, u32 obj_buf_sz); +/// attach a BPF program to a kernel hook. +i32 wasm_attach_bpf_program(u64 obj, u32 name, + u32 attach_target); +/// poll a BPF buffer and call a Wasm callback indicated by sample_func. +/// the first call to this function will open and create a BPF buffer. +i32 wasm_bpf_buffer_poll(u64 program, i32 fd, u32 sample_func, + u32 ctx, u32 data, i32 max_size, + i32 timeout_ms); +/// perform lookup, update, delete, and get_next_key operations on a BPF map. +i32 wasm_bpf_map_operate(u64 fd, i32 cmd, u32 key, u32 value, + u32 next_key, u64 flags); +``` + +- `iXX` denotes signed integer with `XX` bits +- `uXX` denotes unsigned integer with `XX` bits + +## How to compile this plugin + +### Install dependencies + +See the for how to build `WasmEdge` from source. + +#### libbpf + +This plugin requires `libbpf >= 1.2` + +Follow [https://github.com/libbpf/libbpf#building-libbpf](https://github.com/libbpf/libbpf#building-libbpf) to build and install `libbpf`. + +### Build `wasm_bpf` plug-in + +Run the following commands at the root of the `WasmEdge` project: + +- Note: It's important to set `WASMEDGE_PLUGIN_WASM_BPF` to `TRUE` in the command line. This toggle controls the build of `wasm_bpf` plugin. + +```sh +cmake -DWASMEDGE_PLUGIN_WASM_BPF:BOOL=TRUE -B ./build -G "Unix Makefiles" +cmake --build ./build +``` + +## How to use this plugin + +You can either download the examples or build them by yourself. + +### Download the examples + +```sh +wget https://eunomia-bpf.github.io/wasm-bpf/examples/runqlat/runqlat.wasm +``` + +### build the examples + +Examples of wasm-bpf programs can be found in [wasm-bpf](https://github.com/eunomia-bpf/wasm-bpf/tree/main/examples) repo. You can build them by running the following commands: + +```sh +# install the wasi-sdk if you don't have it +wget https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-17/wasi-sdk-17.0-linux.tar.gz +tar -zxf wasi-sdk-17.0-linux.tar.gz +sudo mkdir -p /opt/wasi-sdk/ && sudo mv wasi-sdk-17.0/* /opt/wasi-sdk/ + +# build the examples +git clone https://github.com/eunomia-bpf/wasm-bpf +cd wasm-bpf/examples +git submodule update --init --recursive + +# for example, build the execve example +cd execve && make +``` + +All examples are: + +```console +$ ls +bootstrap execve go-execve go-lsm lsm opensnoop runqlat rust-bootstrap sockfilter sockops +``` + +### run the examples + +After building, you can find the plug-in `./build/plugins/wasm_bpf/libwasmedgePluginWasmBpf.so` and the WasmEdge CLI tool at `./build/tools/wasmedge/wasmedge`. + +Set `WASMEDGE_PLUGIN_PATH=./build/plugins/wasm_bpf/` and run wasmedge: + +```console +# WASMEDGE_PLUGIN_PATH=./build/plugins/wasm_bpf/ ./build/tools/wasmedge/wasmedge execve.wasm + +[289150] node -> /bin/sh -c which ps +[289151] sh -> which ps +[289152] node -> /bin/sh -c /usr/bin/ps -ax -o pid=,ppid=,pcpu=,pmem=,c +[289153] sh -> /usr/bin/ps -ax -o pid=,ppid=,pcpu=,pmem=,command= +[289154] node -> /bin/sh -c "/root/.vscode-server-insiders/bin/96a795cc +[289155] sh -> /root/.vscode-server-insiders/bin/96a795cc0 245632 245678 289148 +[289156] cpuUsage.sh -> sed -n s/^cpu\s//p /proc/stat +[289157] cpuUsage.sh -> cat /proc/245632/stat +[289158] cpuUsage.sh -> cat /proc/245678/stat +[289159] cpuUsage.sh -> cat /proc/289148/stat +[289160] cpuUsage.sh -> sleep 1 +^C +``` diff --git a/plugins/wasm_bpf/bpf-api.h b/plugins/wasm_bpf/bpf-api.h new file mode 100644 index 00000000..72140842 --- /dev/null +++ b/plugins/wasm_bpf/bpf-api.h @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "executor/executor.h" +#include "runtime/instance/module.h" +#include "wasmedge/wasmedge.h" + +#pragma GCC diagnostic push +#ifdef __clang__ +// Allow compilation using clang +#pragma GCC diagnostic warning "-Wextern-c-compat" + +#endif +extern "C" { +#include +#include +} + +#pragma GCC diagnostic pop + +#define POLL_TIMEOUT_MS 100 +#define PERF_BUFFER_PAGES 64 +#define DEBUG_LIBBPF_RUNTIME 0 +#define DEBUG_PRINT_BUFFER_SIZE 1024 + +namespace WasmEdge { +namespace Host { + +/// \brief Initialize libbpf callbacks. +void init_libbpf(void); + +typedef int32_t (*bpf_buffer_sample_fn)(void *ctx, void *data, size_t size); + +/// An abstraction of a BPF ring buffer or perf buffer. +/// See https://github.com/iovisor/bcc/blob/master/libbpf-tools/compat.c. +class bpf_buffer { +protected: + bpf_buffer_sample_fn fn; + WasmEdge_ExecutorContext *wasm_executor; + const WasmEdge_ModuleInstanceContext *wasm_module_instance; + uint32_t wasm_ctx; + uint32_t wasm_sample_function; + void *poll_data; + size_t max_poll_size; + uint32_t wasm_buf_ptr; + +public: + /// Sample callback that calls the Wasm handler indirectly. + int32_t bpf_buffer_sample(void *data, size_t size); + /// Check whether the BPF buffer is valid. + /// + /// A valid module instance should have only one table and a sample function. + bool is_valid() const; + /// Set the Wasm callback parameters. + void + set_callback_params(WasmEdge_ExecutorContext *executor, + const WasmEdge_ModuleInstanceContext *module_instance, + uint32_t sample_func, void *data, size_t max_size, + uint32_t ctx, uint32_t buf_ptr); + /// Poll the BPF buffer. + virtual int32_t bpf_buffer__poll(int32_t timeout_ms) = 0; + /// Open the BPF buffer map. + virtual int32_t bpf_buffer__open(int32_t fd, bpf_buffer_sample_fn sample_cb, + void *ctx) = 0; + virtual ~bpf_buffer() noexcept = default; +}; + +/// BPF program instance. +class wasm_bpf_program { + std::unique_ptr obj{nullptr, + bpf_object__close}; + std::unique_ptr buffer; + std::unordered_set> + links; + +public: + /// Find a BPF map fd by name. + int32_t bpf_map_fd_by_name(const char *name); + /// Load a BPF object from a buffer into the kernel. + int32_t load_bpf_object(const void *obj_buf, size_t obj_buf_sz); + /// Attach a BPF program to a target (e.g. a kernel function on a kprobe). + int32_t attach_bpf_program(const char *name, const char *attach_target); + /// Poll the BPF buffer to get data from the kernel. + int32_t bpf_buffer_poll(WasmEdge_ExecutorContext *executor, + const WasmEdge_ModuleInstanceContext *module_instance, + int32_t fd, int32_t sample_func, uint32_t ctx, + void *buffer_data, size_t max_size, + int32_t timeout_ms, uint32_t wasm_buf_ptr); + /// Get the BPF map pointer by fd. + bpf_map *map_ptr_by_fd(int32_t fd); +}; + +enum bpf_map_cmd { + _BPF_MAP_LOOKUP_ELEM = 1, + _BPF_MAP_UPDATE_ELEM, + _BPF_MAP_DELETE_ELEM, + _BPF_MAP_GET_NEXT_KEY, +}; + +/// Operate on a BPF map. +int32_t bpf_map_operate(int32_t fd, int32_t cmd, void *key, void *value, + void *next_key, uint64_t flags); +using handle_t = int64_t; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-attach-bpf-program.cpp b/plugins/wasm_bpf/func-attach-bpf-program.cpp new file mode 100644 index 00000000..27e2e0d8 --- /dev/null +++ b/plugins/wasm_bpf/func-attach-bpf-program.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func-attach-bpf-program.h" +#include "util.h" + +namespace WasmEdge { +namespace Host { + +Expect AttachBpfProgram::body(const Runtime::CallingFrame &Frame, + handle_t program, uint32_t name, + uint32_t attach_target) { + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::shared_lock lock(state->lock); + auto program_ptr = state->handles.find(program); + if (program_ptr == state->handles.end()) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + const char *name_str = nullptr; + const char *attach_target_str = nullptr; + checkAndSetCstr(memory, name, name_str); + checkAndSetCstr(memory, attach_target, attach_target_str); + return program_ptr->second->attach_bpf_program(name_str, attach_target_str); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-attach-bpf-program.h b/plugins/wasm_bpf/func-attach-bpf-program.h new file mode 100644 index 00000000..a0f990b7 --- /dev/null +++ b/plugins/wasm_bpf/func-attach-bpf-program.h @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// \brief Attach a BPF program to the specified target. +class AttachBpfProgram + : public WasmEdge::Runtime::HostFunction { +public: + AttachBpfProgram(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + handle_t program, uint32_t name, + uint32_t attach_target); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-buffer-poll.cpp b/plugins/wasm_bpf/func-bpf-buffer-poll.cpp new file mode 100644 index 00000000..7fb8ce03 --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-buffer-poll.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func-bpf-buffer-poll.h" +#include "wasmedge/wasmedge.h" +#include + +namespace WasmEdge { +namespace Host { + +// Helper functions of context conversions. +inline const auto *toCallFrameCxt(const Runtime::CallingFrame *Cxt) noexcept { + return reinterpret_cast(Cxt); +} + +Expect BpfBufferPoll::body(const Runtime::CallingFrame &Frame, + handle_t program, int32_t fd, + int32_t sample_func, uint32_t ctx, + uint32_t data, uint32_t max_size, + int32_t timeout_ms) { + auto c_ctx = toCallFrameCxt(&Frame); + auto c_module = WasmEdge_CallingFrameGetModuleInstance(c_ctx); + auto c_executor = WasmEdge_CallingFrameGetExecutor(c_ctx); + if (unlikely(!c_ctx || !c_module || !c_executor)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto module_instance = Frame.getModule(); + if (unlikely(!module_instance)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::shared_lock lock(state->lock); + auto program_ptr = state->handles.find(program); + if (program_ptr == state->handles.end()) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto data_buf = memory->getSpan(data, max_size); + if (data_buf.size() != max_size) { + return Unexpect(ErrCode::Value::HostFuncError); + } + return program_ptr->second->bpf_buffer_poll(c_executor, c_module, fd, + sample_func, ctx, data_buf.data(), + max_size, timeout_ms, data); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-buffer-poll.h b/plugins/wasm_bpf/func-bpf-buffer-poll.h new file mode 100644 index 00000000..63af33a7 --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-buffer-poll.h @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// Perform a BPF buffer poll. If the map is not opened, it will be opened. +/// +/// \param fd the map fd for the BPF buffer. +/// \param sample_func callback function. When data is polled, it will be +/// invoked. +/// \param ctx user-customized variable. +/// \param data data buffer that will be used to store the polled data. +/// \param max_size how many bytes can be put in data. +/// \param timeout_ms how many milliseconds to wait. +/// +/// \return 0 on success, error code on failure. +class BpfBufferPoll : public WasmEdge::Runtime::HostFunction { +public: + BpfBufferPoll(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + handle_t program, int32_t fd, + int32_t sample_func, uint32_t ctx, + uint32_t data, uint32_t max_size, + int32_t timeout_ms); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp b/plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp new file mode 100644 index 00000000..6f6e7172 --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func-bpf-map-fd-by-name.h" +#include "util.h" +#include + +namespace WasmEdge { +namespace Host { + +Expect BpfMapFdByName::body(const Runtime::CallingFrame &Frame, + handle_t program, uint32_t name) { + const char *name_str = nullptr; + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + checkAndSetCstr(memory, name, name_str); + std::shared_lock guard(this->state->lock); + auto program_ptr = state->handles.find(program); + if (program_ptr == state->handles.end()) { + return Unexpect(ErrCode::Value::HostFuncError); + } + return program_ptr->second->bpf_map_fd_by_name(name_str); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-map-fd-by-name.h b/plugins/wasm_bpf/func-bpf-map-fd-by-name.h new file mode 100644 index 00000000..e155d96c --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-map-fd-by-name.h @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// \brief Look up a map fd by its name. +/// +/// Returns the map fd on success; other values indicate failure. +class BpfMapFdByName : public WasmEdge::Runtime::HostFunction { +public: + BpfMapFdByName(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + handle_t program, uint32_t name); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-map-operate.cpp b/plugins/wasm_bpf/func-bpf-map-operate.cpp new file mode 100644 index 00000000..86ba7bb4 --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-map-operate.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func-bpf-map-operate.h" +#include "bpf-api.h" + +extern "C" { +#include +} + +using namespace std::literals; + +namespace WasmEdge { +namespace Host { + +#define ensure_memory_size(var, offset, expected_size) \ + const auto var##_span = memory->getSpan(offset, expected_size); \ + if (var##_span.size() != expected_size) \ + return Unexpect(ErrCode::Value::HostFuncError); \ + const auto var = var##_span.data(); + +Expect +BpfMapOperate::body(const WasmEdge::Runtime::CallingFrame &Frame, int32_t fd, + int32_t cmd, uint32_t key, uint32_t value, + uint32_t next_key, uint64_t flags) { + + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::shared_lock guard(this->state->lock); + bpf_map_info map_info; + memset(&map_info, 0, sizeof(map_info)); + uint32_t info_len = sizeof(map_info); + int32_t err; + if ((err = bpf_map_get_info_by_fd(fd, &map_info, &info_len)) != 0) { + spdlog::debug("[WasmEdge Wasm_bpf] Invalid map fd found: fd={},err={}"sv, + fd, err); + // Invalid map fd + return err; + } + auto key_size = map_info.key_size; + auto value_size = map_info.value_size; + + switch ((bpf_cmd)cmd) { + case BPF_MAP_GET_NEXT_KEY: { + ensure_memory_size(key_ptr, key, key_size); + ensure_memory_size(next_key_ptr, next_key, key_size); + return bpf_map_get_next_key(fd, key_ptr, next_key_ptr); + } + case BPF_MAP_LOOKUP_ELEM: { + ensure_memory_size(key_ptr, key, key_size); + ensure_memory_size(value_ptr, value, value_size); + return bpf_map_lookup_elem_flags(fd, key_ptr, value_ptr, flags); + } + case BPF_MAP_UPDATE_ELEM: { + ensure_memory_size(key_ptr, key, key_size); + ensure_memory_size(value_ptr, value, value_size); + return bpf_map_update_elem(fd, key_ptr, value_ptr, flags); + } + case BPF_MAP_DELETE_ELEM: { + ensure_memory_size(key_ptr, key, key_size); + return bpf_map_delete_elem_flags(fd, key_ptr, flags); + } + default: // More syscall commands can be allowed here + spdlog::debug("[WasmEdge Wasm_bpf] Invalid map operation"sv, cmd); + return -EINVAL; + } +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-map-operate.h b/plugins/wasm_bpf/func-bpf-map-operate.h new file mode 100644 index 00000000..c5779e2f --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-map-operate.h @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// Perform BPF map operations on a specified BPF map through a map fd. +/// +/// Returns zero on success; other values indicate errors. +class BpfMapOperate : public WasmEdge::Runtime::HostFunction { +public: + BpfMapOperate(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + int32_t fd, int32_t cmd, uint32_t key, + uint32_t value, uint32_t next_key, + uint64_t flags); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-close-bpf-object.cpp b/plugins/wasm_bpf/func-close-bpf-object.cpp new file mode 100644 index 00000000..52b428fb --- /dev/null +++ b/plugins/wasm_bpf/func-close-bpf-object.cpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func-close-bpf-object.h" +#include + +namespace WasmEdge { +namespace Host { + +Expect CloseBpfObject::body(const WasmEdge::Runtime::CallingFrame &, + handle_t program) { + std::shared_lock guard(this->state->lock); + auto &handles = this->state->handles; + if (!handles.count(program)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + return handles.erase(program) > 0 ? 0 : -1; +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-close-bpf-object.h b/plugins/wasm_bpf/func-close-bpf-object.h new file mode 100644 index 00000000..ee94ba59 --- /dev/null +++ b/plugins/wasm_bpf/func-close-bpf-object.h @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// \brief Close an opened BPF object and remove map fds from the cache. +/// Returns 0 on success; other values represent error codes. +class CloseBpfObject : public WasmEdge::Runtime::HostFunction { +public: + CloseBpfObject(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + handle_t program); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-load-bpf-object.cpp b/plugins/wasm_bpf/func-load-bpf-object.cpp new file mode 100644 index 00000000..c9c7b682 --- /dev/null +++ b/plugins/wasm_bpf/func-load-bpf-object.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func-load-bpf-object.h" + +namespace WasmEdge { +namespace Host { + +Expect LoadBpfObject::body(const Runtime::CallingFrame &Frame, + uint32_t obj_buf, uint32_t obj_buf_sz) { + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + const auto object_buffer = memory->getSpan(obj_buf, obj_buf_sz); + if (object_buffer.size() != obj_buf_sz) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto program = std::make_unique(); + int32_t res = + program->load_bpf_object(object_buffer.data(), object_buffer.size()); + if (res < 0) + return 0; + auto key = reinterpret_cast(program.get()); + + std::shared_lock guard(state->lock); + state->handles.emplace(key, std::move(program)); + return key; +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-load-bpf-object.h b/plugins/wasm_bpf/func-load-bpf-object.h new file mode 100644 index 00000000..eb51cd8e --- /dev/null +++ b/plugins/wasm_bpf/func-load-bpf-object.h @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// \brief Load a BPF ELF file. +/// +/// A binary file should be provided through a Wasm buffer. wasm_bpf handles +/// the remaining process. Calling this function also caches BPF map fds. +/// +/// \return a handle to a BPF program, which is stored in a map in the global +/// state. Returns 0 on failure. +class LoadBpfObject : public WasmEdge::Runtime::HostFunction { +public: + LoadBpfObject(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + uint32_t obj_buf, uint32_t obj_buf_sz); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/state.h b/plugins/wasm_bpf/state.h new file mode 100644 index 00000000..2b15cd1e --- /dev/null +++ b/plugins/wasm_bpf/state.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "bpf-api.h" +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { + +struct WasmBpfState { + /// Manage BPF programs. + std::unordered_map> handles; + std::shared_mutex lock; + ~WasmBpfState() noexcept = default; +}; + +using state_t = std::shared_ptr; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/util.cpp b/plugins/wasm_bpf/util.cpp new file mode 100644 index 00000000..b3c45b7b --- /dev/null +++ b/plugins/wasm_bpf/util.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "util.h" + +namespace WasmEdge { +namespace Host { + +Expect read_c_str(Runtime::Instance::MemoryInstance *memory, + uint32_t ptr) { + uint32_t tail = ptr; + while (true) { + auto ch = memory->getBytes(tail, 1); + if (!ch.has_value()) + return Unexpect(ch.error()); + if (ch.value()[0] == '\0') + break; + tail++; + } + uint32_t len = tail - ptr + 1; + return memory->getSpan(ptr, len).data(); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/util.h b/plugins/wasm_bpf/util.h new file mode 100644 index 00000000..d55161d9 --- /dev/null +++ b/plugins/wasm_bpf/util.h @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "runtime/instance/memory.h" + +namespace WasmEdge { +namespace Host { + +/// \brief Read a C string from memory and check whether it is null terminated. +/// \param memory memory instance from the wasm runtime. +/// \param ptr the wasm32 buffer pointer +/// \return +WasmEdge::Expect +read_c_str(WasmEdge::Runtime::Instance::MemoryInstance *memory, uint32_t ptr); + +/// Check exist and set cstr, or return `Unexpect`. +#define checkAndSetCstr(memory, name, name_str) \ + do { \ + if (auto res = read_c_str(memory, name); unlikely(!res)) { \ + return Unexpect(res.error()); \ + } else { \ + name_str = *res; \ + } \ + } while (0) + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/wasm-bpf-module.cpp b/plugins/wasm_bpf/wasm-bpf-module.cpp new file mode 100644 index 00000000..a240e6a4 --- /dev/null +++ b/plugins/wasm_bpf/wasm-bpf-module.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasm-bpf-module.h" +#include "func-attach-bpf-program.h" +#include "func-bpf-buffer-poll.h" +#include "func-bpf-map-fd-by-name.h" +#include "func-bpf-map-operate.h" +#include "func-close-bpf-object.h" +#include "func-load-bpf-object.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "state.h" +#include + +namespace WasmEdge { +namespace Host { + +using namespace std::literals::string_view_literals; + +WasmBpfModule::WasmBpfModule() : ModuleInstance("wasm_bpf") { + state_t state = std::make_shared(); + addHostFunc("wasm_load_bpf_object", std::make_unique(state)); + addHostFunc("wasm_close_bpf_object", std::make_unique(state)); + addHostFunc("wasm_attach_bpf_program", + std::make_unique(state)); + addHostFunc("wasm_bpf_buffer_poll", std::make_unique(state)); + addHostFunc("wasm_bpf_map_fd_by_name", + std::make_unique(state)); + addHostFunc("wasm_bpf_map_operate", std::make_unique(state)); +} + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmBpfModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasm_bpf", + .Description = "A plugin provides API for eBPF", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 1, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasm_bpf", + .Description = "Provide functions for eBPF", + .Create = create, + }, + }, + .AddOptions = nullptr}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/wasm-bpf-module.h b/plugins/wasm_bpf/wasm-bpf-module.h new file mode 100644 index 00000000..65eba255 --- /dev/null +++ b/plugins/wasm_bpf/wasm-bpf-module.h @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmBpfModule : public Runtime::Instance::ModuleInstance { +public: + WasmBpfModule(); +}; +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/wasm-bpf.cpp b/plugins/wasm_bpf/wasm-bpf.cpp new file mode 100644 index 00000000..fffe1125 --- /dev/null +++ b/plugins/wasm_bpf/wasm-bpf.cpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include +#include +#include +#include +#include + +#include "bpf-api.h" +#include "common/types.h" +#include "wasmedge/wasmedge.h" + +extern "C" { +#include +#include +} + +using namespace std::literals; + +static int32_t bpf_buffer_sample(void *ctx, void *data, size_t size); +static int32_t libbpf_print_fn(enum libbpf_print_level level, + const char *format, va_list args) { + if (level == LIBBPF_DEBUG && DEBUG_LIBBPF_RUNTIME) + return 0; + char buf[DEBUG_PRINT_BUFFER_SIZE]; + int32_t len = vsnprintf(buf, sizeof(buf), format, args); + spdlog::debug("[WasmEdge Wasm_bpf] {}"sv, buf); + return len; +} + +/// \brief perf buffer sample callback +static void perfbuf_sample_fn(void *ctx, int32_t cpu, void *data, __u32 size) { + static_cast(cpu); + bpf_buffer_sample(ctx, data, size); +} + +/// \brief sample the perf buffer and ring buffer +static int32_t bpf_buffer_sample(void *ctx, void *data, size_t size) { + WasmEdge::Host::bpf_buffer *buffer = + static_cast(ctx); + return buffer->bpf_buffer_sample(data, size); +} + +namespace WasmEdge { +namespace Host { + +/// \brief Initialize the libbpf library. +void init_libbpf(void) { + libbpf_set_strict_mode(LIBBPF_STRICT_ALL); + libbpf_set_print(libbpf_print_fn); +} + +class perf_buffer_wrapper : public bpf_buffer { + std::unique_ptr inner{ + nullptr, perf_buffer__free}; + +public: + perf_buffer_wrapper(bpf_map *events) { + bpf_map__set_type(events, BPF_MAP_TYPE_PERF_EVENT_ARRAY); + bpf_map__set_key_size(events, sizeof(int)); + bpf_map__set_value_size(events, sizeof(int)); + } + int32_t bpf_buffer__poll(int32_t timeout_ms) override { + return perf_buffer__poll(inner.get(), timeout_ms); + } + int32_t bpf_buffer__open(int32_t fd, bpf_buffer_sample_fn sample_cb, + void *ctx) override { + fn = sample_cb; + inner.reset(perf_buffer__new(fd, PERF_BUFFER_PAGES, perfbuf_sample_fn, + nullptr, ctx, nullptr)); + return inner ? 0 : -EINVAL; + } +}; + +struct ring_buffer_wrapper : public bpf_buffer { +public: + std::unique_ptr inner{ + nullptr, ring_buffer__free}; + ring_buffer_wrapper(bpf_map *events) { + bpf_map__set_autocreate(events, false); + } + int32_t bpf_buffer__poll(int32_t timeout_ms) override { + return ring_buffer__poll(inner.get(), timeout_ms); + } + int32_t bpf_buffer__open(int32_t fd, bpf_buffer_sample_fn sample_cb, + void *ctx) override { + inner.reset(ring_buffer__new(fd, sample_cb, ctx, nullptr)); + return inner ? 0 : -1; + } +}; + +void bpf_buffer::set_callback_params( + WasmEdge_ExecutorContext *executor, + const WasmEdge_ModuleInstanceContext *module_instance, uint32_t sample_func, + void *data, size_t max_size, uint32_t ctx, uint32_t buf_ptr) { + wasm_executor = executor; + wasm_module_instance = module_instance; + wasm_sample_function = sample_func; + poll_data = data; + max_poll_size = max_size; + wasm_ctx = ctx; + wasm_buf_ptr = buf_ptr; +} + +bool bpf_buffer::is_valid() const { + auto module_inst = wasm_module_instance; + WasmEdge_String names; + uint32_t exported_table_len = + WasmEdge_ModuleInstanceListTable(module_inst, &names, 1); + if (exported_table_len != 1) { + return false; + } + auto table_inst = WasmEdge_ModuleInstanceFindTable(module_inst, names); + if (!table_inst) { + return false; + } + WasmEdge_Value value; + auto get_data_result = + WasmEdge_TableInstanceGetData(table_inst, &value, wasm_sample_function); + return WasmEdge_ResultOK(get_data_result); +} + +int32_t bpf_buffer::bpf_buffer_sample(void *data, size_t size) { + size_t sample_size = size; + if (max_poll_size < size) { + sample_size = max_poll_size; + } + memcpy(poll_data, data, sample_size); + auto module_inst = wasm_module_instance; + WasmEdge_String names[1]; + /// a valid module instance should have only one table + uint32_t exported_table_len = + WasmEdge_ModuleInstanceListTable(module_inst, names, std::size(names)); + assuming(exported_table_len == 1); + auto table_inst = WasmEdge_ModuleInstanceFindTable(module_inst, names[0]); + assuming(table_inst); + WasmEdge_Value value; + auto get_data_result = + WasmEdge_TableInstanceGetData(table_inst, &value, wasm_sample_function); + assuming(WasmEdge_ResultOK(get_data_result)); + assert(value.Type == WasmEdge_ValType::WasmEdge_ValType_FuncRef); + auto func_ref = WasmEdge_ValueGetFuncRef(value); + + WasmEdge_Value invoke_func_params[3] = { + WasmEdge_ValueGenI32(wasm_ctx), + WasmEdge_ValueGenI32(wasm_buf_ptr), + WasmEdge_ValueGenI32(size), + }; + WasmEdge_Value invoke_func_result; + auto call_result = WasmEdge_ExecutorInvoke( + wasm_executor, func_ref, invoke_func_params, 3, &invoke_func_result, 1); + if (!WasmEdge_ResultOK(call_result)) { + return -EINVAL; + } + return WasmEdge_ValueGetI32(invoke_func_result); +} + +/// \brief Create a BPF buffer based on the object map type. +std::unique_ptr bpf_buffer__new(bpf_map *events) { + bpf_map_type map_type = bpf_map__type(events); + switch (map_type) { + case BPF_MAP_TYPE_PERF_EVENT_ARRAY: + return std::make_unique(events); + case BPF_MAP_TYPE_RINGBUF: + return std::make_unique(events); + default: + return nullptr; + } +} + +/// Get the file descriptor of a map by name. +int32_t wasm_bpf_program::bpf_map_fd_by_name(const char *name) { + return bpf_object__find_map_fd_by_name(obj.get(), name); +} +/// \brief Load all BPF programs and maps in an object file. +int32_t wasm_bpf_program::load_bpf_object(const void *obj_buf, + size_t obj_buf_sz) { + auto object = bpf_object__open_mem(obj_buf, obj_buf_sz, nullptr); + if (!object) { + return static_cast(libbpf_get_error(object)); + } + obj.reset(object); + return bpf_object__load(object); +} + +/// \brief Attach a specific BPF program by name and target. +int32_t wasm_bpf_program::attach_bpf_program(const char *name, + const char *attach_target) { + bpf_link *link; + if (!attach_target) { + // Auto-attach based on bpf_program__section_name. This works well for most + // BPF types, including kprobe, uprobe, fentry, lsm, etc. + link = + bpf_program__attach(bpf_object__find_program_by_name(obj.get(), name)); + } else { + bpf_object *o = obj.get(); + bpf_program *prog = bpf_object__find_program_by_name(o, name); + if (!prog) { + spdlog::error("[WasmEdge Wasm_bpf] get prog {} fail"sv, name); + return -1; + } + // TODO: attach dynamically based on bpf_program__section_name(prog) and + // attach_target to support more attach types that libbpf cannot auto + // attach. For example, if bpf_program__section_name(prog) is "xdp" and + // attach_target is "eth0", or attach sockops to a socket fd. For now, we + // will try auto attach as well. + link = + bpf_program__attach(bpf_object__find_program_by_name(obj.get(), name)); + } + if (!link) { + return static_cast(libbpf_get_error(link)); + } + links.emplace(std::unique_ptr{ + link, bpf_link__destroy}); + return 0; +} + +/// \brief get map pointer by fd through iterating over all maps +bpf_map *wasm_bpf_program::map_ptr_by_fd(int fd) { + bpf_map *curr = nullptr; + bpf_map__for_each(curr, obj.get()) { + if (bpf_map__fd(curr) == fd) { + return curr; + } + } + return nullptr; +} + +/// Poll the buffer; if the buffer is not created, create it. +int32_t wasm_bpf_program::bpf_buffer_poll( + WasmEdge_ExecutorContext *executor, + const WasmEdge_ModuleInstanceContext *module_instance, int32_t fd, + int32_t sample_func, uint32_t ctx, void *data, size_t max_size, + int32_t timeout_ms, uint32_t wasm_buf_ptr) { + int32_t res; + if (!buffer.get()) { + // create buffer + auto map = map_ptr_by_fd(fd); + buffer = bpf_buffer__new(map); + if (!buffer) { + return -1; + } + res = buffer->bpf_buffer__open(fd, bpf_buffer_sample, buffer.get()); + if (res < 0) { + return res; + } + } + buffer->set_callback_params(executor, module_instance, + static_cast(sample_func), data, + max_size, ctx, wasm_buf_ptr); + if (!buffer->is_valid()) { + return -EINVAL; + } + // poll the buffer + return buffer->bpf_buffer__poll(timeout_ms); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/CMakeLists.txt b/plugins/wasmedge_ffmpeg/CMakeLists.txt new file mode 100644 index 00000000..ad7b867b --- /dev/null +++ b/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +find_package(PkgConfig REQUIRED) +pkg_check_modules(LIBAV REQUIRED IMPORTED_TARGET + libavdevice + libavfilter + libavformat + libavcodec + libswresample + libswscale + libavutil +) + +wasmedge_add_library(wasmedgePluginWasmEdgeFFmpeg + SHARED + + avcodec/avCodecContext.cpp + avcodec/avCodec.cpp + avcodec/avCodecParameters.cpp + avcodec/avPacket.cpp + avcodec/avcodec_func.cpp + avcodec/module.cpp + + avdevice/avDevice_func.cpp + avdevice/module.cpp + + avfilter/buffer_source_sink.cpp + avfilter/avFilter.cpp + avfilter/avfilter_func.cpp + avfilter/module.cpp + + avformat/avformatContext.cpp + avformat/avInputOutputFormat.cpp + avformat/avStream.cpp + avformat/avChapter.cpp + avformat/avformat_func.cpp + avformat/module.cpp + + avutil/error.cpp + avutil/avRational.cpp + avutil/avFrame.cpp + avutil/pixfmt.cpp + avutil/samplefmt.cpp + avutil/avDictionary.cpp + avutil/avTime.cpp + avutil/avutil_func.cpp + avutil/module.cpp + + swresample/swresample_func.cpp + swresample/module.cpp + + swscale/swscale_func.cpp + swscale/module.cpp + + ffmpeg_env.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeFFmpeg + PUBLIC + -DWASMEDGE_PLUGIN + -Wno-deprecated-declarations +) + +target_include_directories(wasmedgePluginWasmEdgeFFmpeg + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(wasmedgePluginWasmEdgeFFmpeg + PUBLIC + PkgConfig::LIBAV +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeFFmpeg + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeFFmpeg + PRIVATE + wasmedge_shared + ) +endif() + +install( + TARGETS wasmedgePluginWasmEdgeFFmpeg + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp new file mode 100644 index 00000000..565a161a --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avCodec.h" + +extern "C" { +#include "libavcodec/avcodec.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +Expect AVCodecID::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return FFmpegUtils::CodecID::fromAVCodecID(AvCodec->id); +} + +Expect AVCodecType::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return FFmpegUtils::MediaType::fromMediaType(AvCodec->type); +} + +Expect AVCodecMaxLowres::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return AvCodec->max_lowres; +} + +Expect AVCodecCapabilities::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return AvCodec->capabilities; +} + +Expect AVCodecGetNameLen::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return strlen(AvCodec->name); +} + +Expect AVCodecGetName::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecId, uint32_t NamePtr, + uint32_t NameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + + const char *Name = AvCodec->name; + std::copy_n(Name, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVCodecGetLongNameLen::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return strlen(AvCodec->long_name); +} + +Expect AVCodecGetLongName::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecId, + uint32_t LongNamePtr, + uint32_t LongNameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LongNameBuf, MemInst, char, LongNamePtr, LongNameLen, ""); + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + + const char *LongName = AvCodec->long_name; + std::copy_n(LongName, LongNameLen, LongNameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVCodecProfiles::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->profiles) { + return 1; + } + return 0; +} + +Expect AVCodecPixFmtsIsNull::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->pix_fmts == nullptr) { + return 1; + } + return 0; +} + +Expect AVCodecPixFmtsIter::body(const Runtime::CallingFrame &, + uint32_t AvCodecId, uint32_t Idx) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + AVPixelFormat const *PixelFormat = AvCodec->pix_fmts; + if (PixelFormat == nullptr) { + return 0; + } + + uint32_t Curr = 0; + while (Curr < Idx) { + PixelFormat++; + Curr++; + } + + return FFmpegUtils::PixFmt::fromAVPixFmt(*PixelFormat); +} + +Expect +AVCodecSupportedFrameratesIsNull::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->supported_framerates == nullptr) { + return 1; + } + return 0; +} + +Expect +AVCodecSupportedFrameratesIter::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecId, uint32_t Idx, + uint32_t NumPtr, uint32_t DenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(NumId, MemInst, int32_t, NumPtr, + "Failed when accessing the return NumPtr Memory"sv); + MEM_PTR_CHECK(DenId, MemInst, int32_t, DenPtr, + "Failed when accessing the return DenPtr Memory"sv); + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + AVRational const *Rational = AvCodec->supported_framerates; + + if (Rational == nullptr) { + *NumId = 0; + *DenId = 0; + return static_cast(ErrNo::Success); + } + + uint32_t Curr = 0; + while (Curr < Idx) { + Rational++; + Curr++; + } + + *NumId = Rational->num; + *DenId = Rational->den; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecSupportedSampleRatesIsNull::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->supported_samplerates == nullptr) { + return 1; + } + return 0; +} + +Expect +AVCodecSupportedSampleRatesIter::body(const Runtime::CallingFrame &, + uint32_t AvCodecId, uint32_t Idx) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + const int32_t *SampleRates = AvCodec->supported_samplerates; + if (SampleRates == nullptr) { + return 0; + } + + uint32_t Curr = 0; + while (Curr < Idx) { + SampleRates++; + Curr++; + } + + return *SampleRates; +} + +Expect AVCodecChannelLayoutIsNull::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->ch_layouts == nullptr) { + return 1; + } + return 0; +} + +Expect AVCodecChannelLayoutIter::body(const Runtime::CallingFrame &, + uint32_t AvCodecId, + uint32_t Idx) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + const AVChannelLayout *ChannelLayout = AvCodec->ch_layouts; + if (ChannelLayout == nullptr) { + return 0; + } + + uint32_t Curr = 0; + while (Curr < Idx) { + ChannelLayout++; + Curr++; + } + + return FFmpegUtils::ChannelLayout::intoChannelLayoutID(ChannelLayout->u.mask); +} + +Expect AVCodecSampleFmtsIsNull::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->sample_fmts == nullptr) { + return 1; + } + return 0; +} + +Expect AVCodecSampleFmtsIter::body(const Runtime::CallingFrame &, + uint32_t AvCodecId, uint32_t Idx) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + AVSampleFormat const *SampleFormat = AvCodec->sample_fmts; + if (SampleFormat == nullptr) { + return 0; + } + + uint32_t Curr = 0; + while (Curr < Idx) { + SampleFormat++; + Curr++; + } + + return FFmpegUtils::SampleFmt::toSampleID(*SampleFormat); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodec.h b/plugins/wasmedge_ffmpeg/avcodec/avCodec.h new file mode 100644 index 00000000..70676c2f --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodec.h @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class AVCodecID : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecMaxLowres : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecCapabilities : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecGetNameLen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecGetName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t NamePtr, uint32_t NameLen); +}; + +class AVCodecGetLongNameLen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecGetLongName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t LongNamePtr, uint32_t LongNameLen); +}; + +class AVCodecProfiles : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecPixFmtsIsNull : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecPixFmtsIter : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t Idx); +}; + +class AVCodecSupportedFrameratesIsNull + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecSupportedFrameratesIter + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t Idx, uint32_t NumPtr, uint32_t DenPtr); +}; + +class AVCodecSupportedSampleRatesIsNull + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecSupportedSampleRatesIter + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t Idx); +}; + +class AVCodecChannelLayoutIsNull + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecChannelLayoutIter : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t Idx); +}; + +class AVCodecSampleFmtsIsNull : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecSampleFmtsIter : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t Idx); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp new file mode 100644 index 00000000..c50316fa --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp @@ -0,0 +1,761 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avCodecContext.h" + +extern "C" { +#include "libavcodec/avcodec.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +Expect AVCodecCtxCodecID::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVCodecID const AvCodecId = AvCodecCtx->codec_id; + return FFmpegUtils::CodecID::fromAVCodecID(AvCodecId); +} + +Expect AVCodecCtxCodecType::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVMediaType const AvMediaType = AvCodecCtx->codec_type; + return FFmpegUtils::MediaType::fromMediaType(AvMediaType); +} + +Expect AVCodecCtxSetCodecType::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t CodecTypeId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVMediaType const AvMediaType = + FFmpegUtils::MediaType::intoMediaType(CodecTypeId); + + AvCodecCtx->codec_type = AvMediaType; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetTimebase::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Num, + int32_t Den) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVRational const Rational = av_make_q(Num, Den); + AvCodecCtx->time_base = Rational; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxTimeBase::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t NumPtr, + uint32_t DenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVRational const AvRational = AvCodecCtx->time_base; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxWidth::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->width; +} + +Expect AVCodecCtxSetWidth::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Width) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->width = Width; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxHeight::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->height; +} + +Expect AVCodecCtxSetHeight::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Height) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->height = Height; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSampleAspectRatio::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t NumPtr, + uint32_t DenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + + const AVRational AvRational = AvCodecCtx->sample_aspect_ratio; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetSampleAspectRatio::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Num, + int32_t Den) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + const AVRational AspectRatio = av_make_q(Num, Den); + AvCodecCtx->sample_aspect_ratio = AspectRatio; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxChannelLayout::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + // Deprecated method + uint64_t const AvChannel = AvCodecCtx->ch_layout.u.mask; + return FFmpegUtils::ChannelLayout::intoChannelLayoutID(AvChannel); +} + +Expect AVCodecCtxSetChannelLayout::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint64_t ChannelLayoutId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + uint64_t const AvChannel = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + av_channel_layout_from_mask(&AvCodecCtx->ch_layout, AvChannel); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxPixFormat::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVPixelFormat const PixFmt = AvCodecCtx->pix_fmt; + return FFmpegUtils::PixFmt::fromAVPixFmt(PixFmt); +} + +Expect AVCodecCtxSetPixFormat::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t PixFmtId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVPixelFormat const PixFmt = FFmpegUtils::PixFmt::intoAVPixFmt(PixFmtId); + AvCodecCtx->pix_fmt = PixFmt; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSampleFormat::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVSampleFormat const AvSampleFormat = AvCodecCtx->sample_fmt; + return FFmpegUtils::SampleFmt::toSampleID(AvSampleFormat); +} + +Expect AVCodecCtxSetSampleFormat::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t SampleFmtId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVSampleFormat const SampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + AvCodecCtx->sample_fmt = SampleFormat; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSampleRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->sample_rate; +} + +Expect AVCodecCtxSetSampleRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t SampleRate) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->sample_rate = SampleRate; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetGopSize::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t GopSize) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->gop_size = GopSize; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMaxBFrames::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MaxBFrames) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->max_b_frames = MaxBFrames; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetBQuantFactor::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float BQuantFactor) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->b_quant_factor = BQuantFactor; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetBQuantOffset::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float BQuantOffset) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->b_quant_offset = BQuantOffset; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetIQuantFactor::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float IQuantFactor) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->i_quant_factor = IQuantFactor; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetIQuantOffset::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float IQuantOffset) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->i_quant_offset = IQuantOffset; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetLumiMasking::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float LumiMasking) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->lumi_masking = LumiMasking; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetTemporalCplxMasking::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float TemporalCplxMasking) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->temporal_cplx_masking = TemporalCplxMasking; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetSpatialCplxMasking::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float SpatialCplxMasking) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->spatial_cplx_masking = SpatialCplxMasking; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetPMasking::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float PMasking) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->p_masking = PMasking; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetDarkMasking::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float DarkMasking) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->dark_masking = DarkMasking; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMeCmp::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t MeCmp) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->me_cmp = MeCmp; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMeSubCmp::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MeSubCmp) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->me_sub_cmp = MeSubCmp; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMbCmp::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t MbCmp) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->mb_cmp = MbCmp; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetIldctCmp::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t IldctCmp) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->ildct_cmp = IldctCmp; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetDiaSize::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t DiaSize) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->dia_size = DiaSize; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetLastPredictorsCount::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t LastPredictorCount) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->last_predictor_count = LastPredictorCount; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMePreCmp::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MePreCmp) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->me_pre_cmp = MePreCmp; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetPreDiaSize::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t PreDiaSize) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->pre_dia_size = PreDiaSize; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetMeSubpelQuality::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MeSubpelQuality) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->me_subpel_quality = MeSubpelQuality; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMeRange::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MeRange) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->me_range = MeRange; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMbDecision::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MbDecision) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->mb_decision = MbDecision; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMbLMin::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MbLMin) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->mb_lmin = MbLMin; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMbLMax::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MbLMax) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->mb_lmax = MbLMax; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxIntraDcPrecision::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->intra_dc_precision; +} + +Expect +AVCodecCtxSetIntraDcPrecision::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t IntraDcPrecision) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->intra_dc_precision = IntraDcPrecision; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetQMin::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t QMin) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->qmin = QMin; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetQMax::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t QMax) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->qmax = QMax; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetGlobalQuality::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t GlobalQuality) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->global_quality = GlobalQuality; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetColorspace::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ColorspaceId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVColorSpace const ColorSpace = + FFmpegUtils::ColorSpace::intoAVColorSpace(ColorspaceId); + AvCodecCtx->colorspace = ColorSpace; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxColorspace::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVColorSpace const Colorspace = AvCodecCtx->colorspace; + return FFmpegUtils::ColorSpace::fromAVColorSpace(Colorspace); +} + +Expect AVCodecCtxSetColorRange::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ColorRangeId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->color_range = static_cast(ColorRangeId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxColorRange::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVColorRange const ColorRange = AvCodecCtx->color_range; + return static_cast(ColorRange); +} + +Expect AVCodecCtxFrameSize::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->frame_size; +} + +Expect AVCodecCtxBitRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->bit_rate; +} + +Expect AVCodecCtxSetBitRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int64_t BitRate) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->bit_rate = BitRate; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxRcMaxRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->rc_max_rate; +} + +Expect AVCodecCtxSetRcMaxRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int64_t RcMaxRate) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->rc_max_rate = RcMaxRate; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetBitRateTolerance::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t BitRateTolerance) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->bit_rate_tolerance = BitRateTolerance; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetCompressionLevel::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t CompressionLevel) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->compression_level = CompressionLevel; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxFrameRate::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, + uint32_t NumPtr, uint32_t DenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + + AVRational const FrameRate = AvCodecCtx->framerate; + *Num = FrameRate.num; + *Den = FrameRate.den; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetFrameRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Num, + int32_t Den) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVRational const Rational = av_make_q(Num, Den); + AvCodecCtx->framerate = Rational; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetFlags::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Flags) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->flags = Flags; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetStrictStdCompliance::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ComplianceId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->strict_std_compliance = ComplianceId; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetDebug::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Debug) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->debug = Debug; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxCodec::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, + uint32_t AvCodecPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AvCodecPtr, + "Failed to access Ptr for AvCodecPtr"sv); + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvCodec, *AVCodecId, const AVCodec); + + AvCodec = AvCodecCtx->codec; + if (AvCodec == nullptr) + return -1; + + FFMPEG_PTR_STORE(const_cast(AvCodec), AVCodecId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxChannels::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->ch_layout.nb_channels; +} + +Expect AVCodecCtxSetChannels::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Channels) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->ch_layout.nb_channels = Channels; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetSkipLoopFilter::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t AVDiscardId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->skip_loop_filter = static_cast(AVDiscardId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetSkipFrame::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t AVDiscardId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->skip_frame = static_cast(AVDiscardId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetSkipIdct::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t AVDiscardId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->skip_idct = static_cast(AVDiscardId); + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetErrorConcealment::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ErrorConcealment) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->error_concealment = ErrorConcealment; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetErrorRecognition::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ErrRecognition) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->err_recognition = ErrRecognition; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxDelay::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->delay; +} + +Expect AVCodecCtxSetSkipTop::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Value) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->skip_top = Value; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetSkipBottom::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Value) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->skip_bottom = Value; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxRefs::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->refs; +} + +Expect AVCodecCtxSetSliceFlags::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Value) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->slice_flags = Value; + return static_cast(ErrNo::Success); +} +Expect AVCodecCtxSetSliceCount::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Value) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->slices = Value; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetFieldOrder::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Value) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->field_order = static_cast(Value); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxColorTrc::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return static_cast(AvCodecCtx->color_trc); +} + +Expect +AVCodecCtxChromaSampleLocation::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVChromaLocation const Chroma = AvCodecCtx->chroma_sample_location; + return FFmpegUtils::ChromaLocation::fromAVChromaLocation(Chroma); +} + +Expect AVCodecCtxFrameNumber::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->frame_num; +} + +Expect AVCodecCtxBlockAlign::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->block_align; +} + +Expect +AVCodecCtxSetRequestSampleFmt::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t SampleFmtId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVSampleFormat const SampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + AvCodecCtx->request_sample_fmt = SampleFmt; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxAudioServiceType::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVAudioServiceType const AudioServiceType = AvCodecCtx->audio_service_type; + return static_cast(AudioServiceType); +} + +Expect AVCodecCtxHasBFrames::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->has_b_frames; +} + +Expect AVCodecCtxActiveThreadType::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->active_thread_type; +} + +Expect AVCodecCtxSetThreadType::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ThreadType) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->thread_type = ThreadType; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxThreadCount::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->thread_count; +} + +Expect AVCodecCtxSetThreadCount::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ThreadCount) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->thread_count = ThreadCount; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxColorPrimaries::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVColorPrimaries const ColorPrimaries = AvCodecCtx->color_primaries; + return FFmpegUtils::ColorPrimaries::fromAVColorPrimaries(ColorPrimaries); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h new file mode 100644 index 00000000..d1696074 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h @@ -0,0 +1,678 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class AVCodecCtxCodecID : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxCodecType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetCodecType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t CodecTypeId); +}; + +class AVCodecCtxSetTimebase : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Num, int32_t Den); +}; + +class AVCodecCtxTimeBase : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr); +}; + +class AVCodecCtxWidth : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetWidth : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Width); +}; + +class AVCodecCtxHeight : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetHeight : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Height); +}; + +class AVCodecCtxSampleAspectRatio + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr); +}; + +class AVCodecCtxSetSampleAspectRatio + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Num, int32_t Den); +}; + +class AVCodecCtxChannelLayout : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetChannelLayout + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint64_t ChannelLayoutId); +}; + +class AVCodecCtxPixFormat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetPixFormat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t PixFmtId); +}; + +class AVCodecCtxSampleFormat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetSampleFormat + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t SampleFmtId); +}; + +class AVCodecCtxSampleRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetSampleRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t SampleRate); +}; + +class AVCodecCtxSetGopSize : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t GopSize); +}; + +class AVCodecCtxSetMaxBFrames : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MaxBFrames); +}; + +class AVCodecCtxSetBQuantFactor + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float BQuantFactor); +}; + +class AVCodecCtxSetBQuantOffset + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float BQuantOffset); +}; + +class AVCodecCtxSetIQuantFactor + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float IQuantFactor); +}; + +class AVCodecCtxSetIQuantOffset + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float IQuantOffset); +}; + +class AVCodecCtxSetLumiMasking : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float LumiMasking); +}; + +class AVCodecCtxSetTemporalCplxMasking + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float TemporalCplxMasking); +}; + +class AVCodecCtxSetSpatialCplxMasking + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float SpatialCplxMasking); +}; + +class AVCodecCtxSetPMasking : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float PMasking); +}; + +class AVCodecCtxSetDarkMasking : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float DarkMasking); +}; + +class AVCodecCtxSetMeCmp : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MeCmp); +}; + +class AVCodecCtxSetMeSubCmp : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MeSubCmp); +}; + +class AVCodecCtxSetMbCmp : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MbCmp); +}; + +class AVCodecCtxSetIldctCmp : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t IldctCmp); +}; + +class AVCodecCtxSetDiaSize : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t DiaSize); +}; + +class AVCodecCtxSetLastPredictorsCount + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t LastPredictorCount); +}; + +class AVCodecCtxSetMePreCmp : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MePreCmp); +}; + +class AVCodecCtxSetPreDiaSize : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t PreDiaSize); +}; + +class AVCodecCtxSetMeSubpelQuality + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MeSubpelQuality); +}; + +class AVCodecCtxSetMeRange : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MeRange); +}; + +class AVCodecCtxSetMbDecision : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MbDecision); +}; + +class AVCodecCtxSetMbLMin : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MbLMin); +}; + +class AVCodecCtxSetMbLMax : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MbLMax); +}; + +class AVCodecCtxIntraDcPrecision + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetIntraDcPrecision + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t IntraDcPrecision); +}; + +class AVCodecCtxSetQMin : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t QMin); +}; + +class AVCodecCtxSetQMax : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t QMax); +}; + +class AVCodecCtxSetGlobalQuality + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t GlobalQuality); +}; + +class AVCodecCtxSetColorspace : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ColorspaceId); +}; + +class AVCodecCtxColorspace : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetColorRange : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ColorRange); +}; + +class AVCodecCtxColorRange : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxFrameSize : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxBitRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetBitRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int64_t BitRate); +}; + +class AVCodecCtxRcMaxRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetRcMaxRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int64_t RcMaxRate); +}; + +class AVCodecCtxSetBitRateTolerance + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t BitRateTolerance); +}; + +class AVCodecCtxSetCompressionLevel + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t CompressionLevel); +}; + +class AVCodecCtxFrameRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr); +}; + +class AVCodecCtxSetFrameRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Num, int32_t Den); +}; + +class AVCodecCtxSetFlags : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Flags); +}; + +class AVCodecCtxSetStrictStdCompliance + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ComplianceId); +}; + +class AVCodecCtxSetDebug : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Debug); +}; + +class AVCodecCtxCodec : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t AvCodecPtr); +}; + +class AVCodecCtxChannels : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetChannels : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Channels); +}; + +class AVCodecCtxSetSkipLoopFilter + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t AVDicardId); +}; + +class AVCodecCtxSetSkipFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t AVDiscardId); +}; + +class AVCodecCtxSetSkipIdct : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t AVDicardId); +}; + +class AVCodecCtxSetErrorConcealment + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ErrorConcealment); +}; + +class AVCodecCtxSetErrorRecognition + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ErrorRecognition); +}; + +class AVCodecCtxDelay : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetSkipTop : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Value); +}; + +class AVCodecCtxSetSkipBottom : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Value); +}; + +class AVCodecCtxRefs : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetSliceFlags : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Flags); +}; + +class AVCodecCtxSetSliceCount : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Value); +}; + +class AVCodecCtxSetFieldOrder : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Value); +}; + +class AVCodecCtxColorTrc : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxChromaSampleLocation + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxFrameNumber : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxBlockAlign : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetRequestSampleFmt + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t SampleFmtId); +}; + +class AVCodecCtxAudioServiceType + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxHasBFrames : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxActiveThreadType + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetThreadType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ThreadType); +}; + +class AVCodecCtxThreadCount : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetThreadCount : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ThreadCount); +}; + +class AVCodecCtxColorPrimaries : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp new file mode 100644 index 00000000..3724aa6e --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avCodecParameters.h" + +extern "C" { +#include "libavcodec/avcodec.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +Expect AVCodecParamCodecId::body(const Runtime::CallingFrame &, + uint32_t AvCodecParamId) { + FFMPEG_PTR_FETCH(AvCodecParams, AvCodecParamId, AVCodecParameters); + return FFmpegUtils::CodecID::fromAVCodecID(AvCodecParams->codec_id); +} + +Expect AVCodecParamCodecType::body(const Runtime::CallingFrame &, + uint32_t AvCodecParamId) { + FFMPEG_PTR_FETCH(AvCodecParams, AvCodecParamId, AVCodecParameters); + return FFmpegUtils::MediaType::fromMediaType(AvCodecParams->codec_type); +} + +Expect AVCodecParamSetCodecTag::body(const Runtime::CallingFrame &, + uint32_t AvCodecParamId, + uint32_t CodecTag) { + FFMPEG_PTR_FETCH(AvCodecParams, AvCodecParamId, AVCodecParameters); + AvCodecParams->codec_tag = CodecTag; + return static_cast(ErrNo::Success); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h new file mode 100644 index 00000000..fc6557d3 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class AVCodecParamCodecId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamId); +}; + +class AVCodecParamCodecType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamId); +}; + +class AVCodecParamSetCodecTag : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamId, uint32_t CodecTag); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp b/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp new file mode 100644 index 00000000..2510f978 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avPacket.h" + +extern "C" { +#include "libavcodec/packet.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +Expect AVPacketAlloc::body(const Runtime::CallingFrame &Frame, + uint32_t AvPacketPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvPacketId, MemInst, uint32_t, AvPacketPtr, + "Failed when accessing the return AVCodecContext Memory"sv); + + FFMPEG_PTR_FETCH(AvPacket, *AvPacketId, AVPacket); // Initialize the packet. + AvPacket = av_packet_alloc(); + FFMPEG_PTR_STORE(AvPacket, AvPacketId); + return static_cast(ErrNo::Success); +} + +Expect AVNewPacket::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int32_t Size) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return av_new_packet(AvPacket, Size); +} + +Expect AVPacketRef::body(const Runtime::CallingFrame &, + uint32_t DestPacketId, uint32_t SrcPacketId) { + FFMPEG_PTR_FETCH(DestAvPacket, DestPacketId, AVPacket); + FFMPEG_PTR_FETCH(SrcAvPacket, SrcPacketId, AVPacket); + + return av_packet_ref(DestAvPacket, SrcAvPacket); +} + +Expect AVPacketUnref::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); // Free packet. + av_packet_unref(AvPacket); + FFMPEG_PTR_DELETE(AvPacketId); + return static_cast(ErrNo::Success); +} + +Expect AVGrowPacket::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int32_t Size) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return av_grow_packet(AvPacket, Size); +} + +Expect AVShrinkPacket::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int32_t Size) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + av_shrink_packet(AvPacket, Size); + return static_cast(ErrNo::Success); +} + +Expect AVPacketStreamIndex::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->stream_index; +} + +Expect AVPacketSetStreamIndex::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, + int32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->stream_index = StreamIdx; + return static_cast(ErrNo::Success); +} + +Expect AVPacketSize::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->size; +} + +Expect AVPacketFlags::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->flags; +} + +Expect AVPacketSetFlags::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int32_t Flags) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->flags = Flags; + return static_cast(ErrNo::Success); +} + +Expect AVPacketPos::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->pos; +} + +Expect AVPacketSetPos::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int64_t Pos) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->pos = Pos; + return static_cast(ErrNo::Success); +} + +Expect AVPacketDuration::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->duration; +} + +Expect AVPacketSetDuration::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, + int64_t Duration) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->duration = Duration; + return static_cast(ErrNo::Success); +} + +Expect AVPacketDts::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->dts; +} + +Expect AVPacketSetDts::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int64_t Dts) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->dts = Dts; + return static_cast(ErrNo::Success); +} + +Expect AVPacketPts::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->pts; +} + +Expect AVPacketSetPts::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int64_t Pts) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->pts = Pts; + return static_cast(ErrNo::Success); +} + +Expect AVPacketIsDataNull::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + if (AvPacket->data == nullptr) + return 1; + return 0; +} + +Expect AVPacketData::body(const Runtime::CallingFrame &Frame, + uint32_t AvPacketId, uint32_t DataPtr, + uint32_t DataLen) { + MEMINST_CHECK(MemInst, Frame, 0) + MEM_SPAN_CHECK(Buffer, MemInst, uint8_t, DataPtr, DataLen, ""); + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + uint8_t *Data = AvPacket->data; + std::copy_n(Data, DataLen, Buffer.data()); + return static_cast(ErrNo::Success); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avPacket.h b/plugins/wasmedge_ffmpeg/avcodec/avPacket.h new file mode 100644 index 00000000..403d55a2 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avPacket.h @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class AVPacketAlloc : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvPacketPtr); +}; + +class AVNewPacket : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int32_t Size); +}; + +class AVPacketRef : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t DestPacketId, uint32_t SrcPacketId); +}; + +class AVPacketUnref : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVGrowPacket : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int32_t Size); +}; + +class AVShrinkPacket : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int32_t Size); +}; + +class AVPacketStreamIndex : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetStreamIndex : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int32_t StreamIdx); +}; + +class AVPacketSize : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketFlags : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetFlags : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int32_t Flags); +}; + +class AVPacketPos : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetPos : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int64_t Pos); +}; + +class AVPacketDuration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetDuration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int64_t Duration); +}; + +class AVPacketDts : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetDts : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int64_t Dts); +}; + +class AVPacketPts : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetPts : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int64_t Pts); +}; + +class AVPacketIsDataNull : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketData : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + uint32_t DataPtr, uint32_t DataLen); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp new file mode 100644 index 00000000..e7fb71bf --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp @@ -0,0 +1,320 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avcodec_func.h" + +extern "C" { +#include "libavcodec/avcodec.h" +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +Expect AVCodecAllocContext3::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecId, + uint32_t AvCodecCtxPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvCodecCtxId, MemInst, uint32_t, AvCodecCtxPtr, + "Failed when accessing the return AVCodecContext Memory"sv); + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, AVCodec); + + AVCodecContext *AvCodecCtx = avcodec_alloc_context3(AvCodec); + FFMPEG_PTR_STORE(AvCodecCtx, AvCodecCtxId); + return static_cast(ErrNo::Success); +} + +Expect +AVCodecParametersFromContext::body(const Runtime::CallingFrame &, + uint32_t AvCodecParamId, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecParam, AvCodecParamId, AVCodecParameters); + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return avcodec_parameters_from_context(AvCodecParam, AvCodecCtx); +} + +Expect AVCodecParametersFree::body(const Runtime::CallingFrame &, + uint32_t AvCodecParamId) { + FFMPEG_PTR_FETCH(AvCodecParam, AvCodecParamId, AVCodecParameters); + + avcodec_parameters_free(&AvCodecParam); + FFMPEG_PTR_DELETE(AvCodecParamId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecFreeContext::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + + avcodec_free_context(&AvCodecCtx); + FFMPEG_PTR_DELETE(AvCodecCtxId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecParametersAlloc::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvCodecParamId, MemInst, uint32_t, AvCodecParamPtr, + "Failed when accessing the return AVCodecParameters Memory"sv); + + FFMPEG_PTR_FETCH(AvCodecParam, *AvCodecParamId, AVCodecParameters); + AvCodecParam = avcodec_parameters_alloc(); + FFMPEG_PTR_STORE(AvCodecParam, AvCodecParamId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecGetType::body(const Runtime::CallingFrame &, + uint32_t AvCodecIdIndex) { + AVCodecID const AvCodecId = + FFmpegUtils::CodecID::intoAVCodecID(AvCodecIdIndex); + AVMediaType const MediaType = avcodec_get_type(AvCodecId); + return FFmpegUtils::MediaType::fromMediaType(MediaType); +} + +Expect AVCodecOpen2::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, uint32_t AvCodecId, + uint32_t AvDictionaryId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvDictionary, AvDictionaryId, AVDictionary *); + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, AVCodec); + return avcodec_open2(AvCodecCtx, AvCodec, AvDictionary); +} + +Expect AVCodecFindDecoder::body(const Runtime::CallingFrame &Frame, + uint32_t ID, uint32_t AvCodecPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AvCodecPtr, + "Failed when accessing the return AVCodec Memory"sv); + + AVCodecID const Id = FFmpegUtils::CodecID::intoAVCodecID(ID); + + const AVCodec *AvCodec = avcodec_find_decoder(Id); + + // Setting AvCodec value as NULL. + if (AvCodec == nullptr) { + *AVCodecId = 0; + return static_cast(ErrNo::Success); + } + + FFMPEG_PTR_STORE(const_cast(AvCodec), AVCodecId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecIsEncoder::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return av_codec_is_encoder(AvCodec); +} + +Expect AVCodecIsDecoder::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return av_codec_is_decoder(AvCodec); +} + +Expect AVCodecClose::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + int Res = avcodec_close(AvCodecCtx); + FFMPEG_PTR_DELETE(AvCodecCtxId); + return Res; +} + +Expect AVCodecParametersToContext::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t AvCodecParamId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvCodecParam, AvCodecParamId, AVCodecParameters); + + return avcodec_parameters_to_context(AvCodecCtx, AvCodecParam); +} + +Expect AVCodecReceiveFrame::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return avcodec_receive_frame(AvCodecCtx, AvFrame); +} + +Expect AVCodecSendPacket::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t PacketId) { + FFMPEG_PTR_FETCH(AVCodecCtx, AvCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvPacket, PacketId, + AVPacket); // Can send Null AVPacket, to close the stream. + return avcodec_send_packet(AVCodecCtx, AvPacket); +} + +Expect AVCodecFindEncoder::body(const Runtime::CallingFrame &Frame, + uint32_t ID, uint32_t AVCodecPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AVCodecPtr, + "Failed when accessing the return AVCodec Memory"sv); + + AVCodecID const Id = FFmpegUtils::CodecID::intoAVCodecID(ID); + + const AVCodec *AvCodec = avcodec_find_encoder(Id); + + // Setting AvCodec value as NULL. + if (AvCodec == nullptr) { + *AVCodecId = 0; + return static_cast(ErrNo::Success); + } + + FFMPEG_PTR_STORE(const_cast(AvCodec), AVCodecId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecReceivePacket::body(const Runtime::CallingFrame &, + uint32_t AVCodecCtxId, + uint32_t PacketId) { + FFMPEG_PTR_FETCH(AVCodecCtx, AVCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvPacket, PacketId, AVPacket); + return avcodec_receive_packet(AVCodecCtx, AvPacket); +} + +Expect AVCodecSendFrame::body(const Runtime::CallingFrame &, + uint32_t AVCodecCtxId, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AVCodecCtx, AVCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return avcodec_send_frame(AVCodecCtx, AvFrame); +} + +Expect +AVCodecFindDecoderByName::body(const Runtime::CallingFrame &Frame, + uint32_t AVCodecPtr, uint32_t NamePtr, + uint32_t NameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AVCodecPtr, + "Failed when accessing the return AVCodec Memory"sv); + MEM_PTR_CHECK(NameId, MemInst, char, NamePtr, + "Failed when accessing the return URL memory"sv); + + std::string Name; + std::copy_n(NameId, NameLen, std::back_inserter(Name)); + + AVCodec const *AvCodec = avcodec_find_decoder_by_name(Name.c_str()); + + if (AvCodec == nullptr) { + *AVCodecId = 0; + return static_cast(ErrNo::Success); + } + + FFMPEG_PTR_STORE(const_cast(AvCodec), AVCodecId); + return static_cast(ErrNo::Success); +} + +Expect +AVCodecFindEncoderByName::body(const Runtime::CallingFrame &Frame, + uint32_t AVCodecPtr, uint32_t NamePtr, + uint32_t NameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AVCodecPtr, + "Failed when accessing the return AVCodec Memory"sv); + MEM_PTR_CHECK(NameId, MemInst, char, NamePtr, + "Failed when accessing the return URL memory"sv); + + std::string Name; + std::copy_n(NameId, NameLen, std::back_inserter(Name)); + + AVCodec const *AvCodec = avcodec_find_encoder_by_name(Name.c_str()); + + if (AvCodec == nullptr) { + *AVCodecId = 0; + return static_cast(ErrNo::Success); + } + + FFMPEG_PTR_STORE(const_cast(AvCodec), AVCodecId); + return static_cast(ErrNo::Success); +} + +Expect AVPacketRescaleTs::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int32_t SrcNum, + int32_t SrcDen, int32_t DestNum, + int32_t DestDen) { + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AVRational const Src = av_make_q(SrcNum, SrcDen); + AVRational const Dest = av_make_q(DestNum, DestDen); + + av_packet_rescale_ts(AvPacket, Src, Dest); + return static_cast(ErrNo::Success); +} + +Expect AVPacketMakeWritable::body(const Runtime::CallingFrame &, + uint32_t AVPacketId) { + FFMPEG_PTR_FETCH(AvPacket, AVPacketId, AVPacket); + return av_packet_make_writable(AvPacket); +} + +Expect AVCodecParametersCopy::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t AVCodecParamId, + uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvCodecParam, AVCodecParamId, AVCodecParameters); + + AVStream **AvStream = AvFormatCtx->streams; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + return avcodec_parameters_copy((*AvStream)->codecpar, AvCodecParam); +} + +Expect AVCodecVersion::body(const Runtime::CallingFrame &) { + return avcodec_version(); +} + +Expect AVCodecFlushBuffers::body(const Runtime::CallingFrame &, + uint32_t AVCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AVCodecCtxId, AVCodecContext); + avcodec_flush_buffers(AvCodecCtx); + return static_cast(ErrNo::Success); +} + +Expect +AVCodecConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = avcodec_configuration(); + return strlen(Config); +} + +Expect AVCodecConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = avcodec_configuration(); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVCodecLicenseLength::body(const Runtime::CallingFrame &) { + const char *License = avcodec_license(); + return strlen(License); +} + +Expect AVCodecLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, uint32_t LicenseLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = avcodec_license(); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h new file mode 100644 index 00000000..c560e065 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h @@ -0,0 +1,210 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class AVCodecAllocContext3 : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t AvCodecCtxPtr); +}; + +class AVCodecParametersFromContext + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamId, uint32_t AvCodecCtxId); +}; + +class AVCodecParametersFree : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamId); +}; + +class AVCodecFreeContext : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecParametersAlloc : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamPtr); +}; + +class AVCodecGetType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecOpen2 : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t AvCodecId, + uint32_t AvDictionaryId); +}; + +class AVCodecFindDecoder : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t ID, + uint32_t AvCodecId); +}; + +class AVCodecIsEncoder : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecIsDecoder : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecParametersToContext + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t AvCodecParamId); +}; + +class AVCodecReceiveFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t FrameId); +}; + +class AVCodecSendPacket : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t PacketId); +}; + +class AVCodecFindEncoder : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t ID, + uint32_t AVCodecPtr); +}; + +class AVCodecReceivePacket : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVCodecCtxId, uint32_t PacketId); +}; + +class AVCodecSendFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVCodecCtxId, uint32_t FrameId); +}; + +class AVCodecFindDecoderByName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AVCodecPtr, + uint32_t NamePtr, uint32_t NameLen); +}; + +class AVCodecFindEncoderByName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AVCodecPtr, + uint32_t NamePtr, uint32_t NameLen); +}; + +class AVPacketRescaleTs : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AVPacketId, + int32_t SrcNum, int32_t SrcDen, int32_t DestNum, + int32_t DestDen); +}; + +class AVPacketMakeWritable : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t AVPacketId); +}; + +class AVCodecParametersCopy : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVFormatCtxId, uint32_t AVCodecParamId, + uint32_t StreamIdx); +}; + +class AVCodecVersion : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVCodecFlushBuffers : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVCodecCtxId); +}; + +class AVCodecConfigurationLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVCodecConfiguration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class AVCodecLicenseLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVCodecLicense : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/module.cpp b/plugins/wasmedge_ffmpeg/avcodec/module.cpp new file mode 100644 index 00000000..ddae5636 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/module.cpp @@ -0,0 +1,360 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "avCodec.h" +#include "avCodecContext.h" +#include "avCodecParameters.h" +#include "avPacket.h" +#include "avcodec_func.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +WasmEdgeFFmpegAVCodecModule::WasmEdgeFFmpegAVCodecModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_avcodec") { + // avcodec_func.h + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_alloc_context3", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_parameters_from_context", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_parameters_free", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_free_context", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_parameters_alloc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_get_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_open2", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_find_decoder", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_codec_is_encoder", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_codec_is_decoder", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_close", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_parameters_to_context", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_receive_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_send_packet", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_find_encoder", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_receive_packet", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_send_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_find_decoder_by_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_find_encoder_by_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_rescale_ts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_make_writable", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_copy", + std::make_unique(Env)); // TODO: Write Test. + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_flush_buffers", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_license", + std::make_unique(Env)); + + // avCodecContext Struct fields access + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_codec_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_codec_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_codec_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_time_base", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_time_base", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_width", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_width", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_height", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_height", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_sample_aspect_ratio", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_aspect_ratio", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_pix_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_pix_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_sample_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_sample_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_gop_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_max_b_frames", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_b_quant_factor", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_b_quant_offset", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_i_quant_factor", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_i_quant_offset", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_lumi_masking", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_temporal_cplx_masking", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_spatial_cplx_masking", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_p_masking", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_dark_masking", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_cmp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_sub_cmp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_cmp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_ildct_cmp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_dia_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_last_predictor_count", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_pre_cmp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_pre_dia_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_subpel_quality", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_range", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_decision", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_lmin", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_lmax", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_intra_dc_precision", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_intra_dc_precision", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_qmin", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_qmax", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_global_quality", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_colorspace", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_colorspace", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_color_range", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_color_range", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_frame_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_bit_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_bit_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_rc_max_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_rc_max_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_bit_rate_tolerance", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_compression_level", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_framerate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_framerate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_flags", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_strict_std_compliance", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_debug", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_codec", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_channels", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_channels", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_loop_filter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_idct", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_error_concealment", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_err_recognition", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_delay", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_top", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_bottom", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_refs", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_slice_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_slice_count", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_field_order", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_color_trc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_chroma_sample_location", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_frame_number", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_block_align", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_request_sample_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_audio_service_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_has_b_frames", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_active_thread_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_thread_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_thread_count", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_thread_count", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_color_primaries", + std::make_unique(Env)); + + // avCodec Struct fields access + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_max_lowres", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_capabilities", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_get_name_len", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_get_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_get_long_name_len", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_get_long_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_profiles", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_pix_fmts_is_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_pix_fmts_iter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_supported_framerate_is_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_supported_framerate_iter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_supported_samplerates_is_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_supported_samplerates_iter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_channel_layouts_is_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_channel_layouts_iter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_sample_fmts_is_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_sample_fmts_iter", + std::make_unique(Env)); + + // AVCodecParam Struct fields access. + addHostFunc("wasmedge_ffmpeg_avcodec_avcodecparam_codec_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodecparam_codec_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodecparam_set_codec_tag", + std::make_unique(Env)); + + // AVPacket functions. + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_alloc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_new_packet", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_ref", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_unref", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_grow_packet", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_shrink_packet", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_stream_index", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_stream_index", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_pos", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_pos", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_duration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_duration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_dts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_dts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_pts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_pts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_is_data_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_data", + std::make_unique(Env)); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/module.h b/plugins/wasmedge_ffmpeg/avcodec/module.h new file mode 100644 index 00000000..b7e63947 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/module.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class WasmEdgeFFmpegAVCodecModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegAVCodecModule(std::shared_ptr Env); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp new file mode 100644 index 00000000..926c618f --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avDevice_func.h" + +extern "C" { +#include "libavdevice/avdevice.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVDevice { + +Expect AVDeviceRegisterAll::body(const Runtime::CallingFrame &) { + avdevice_register_all(); + return {}; +} + +Expect AVDeviceVersion::body(const Runtime::CallingFrame &) { + return avdevice_version(); +} + +Expect AVDeviceListDevices::body(const Runtime::CallingFrame &Frame, + uint32_t AVFormatCtxId, + uint32_t AVDeviceInfoListPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVDeviceInfoListId, MemInst, uint32_t, AVDeviceInfoListPtr, "") + + FFMPEG_PTR_FETCH(AvFormatCtx, AVFormatCtxId, AVFormatContext); + + AVDeviceInfoList **AvDeviceInfoList = + static_cast(av_malloc(sizeof(AVDeviceInfoList *))); + + int Res = avdevice_list_devices(AvFormatCtx, AvDeviceInfoList); + FFMPEG_PTR_STORE(AvDeviceInfoList, AVDeviceInfoListId); + return Res; +} + +Expect AVInputAudioDeviceNext::body(const Runtime::CallingFrame &) { + spdlog::error("[WasmEdge-FFmpeg] AVInputAudioDeviceNext unimplemented"sv); + // av_input_audio_device_next(); + return static_cast(ErrNo::UnImplemented); +} + +Expect AVInputVideoDeviceNext::body(const Runtime::CallingFrame &) { + spdlog::error("[WasmEdge-FFmpeg] AVInputVideoDeviceNext unimplemented"sv); + // av_input_video_device_next(); + return static_cast(ErrNo::UnImplemented); +} + +Expect AVOutputAudioDeviceNext::body(const Runtime::CallingFrame &) { + spdlog::error("[WasmEdge-FFmpeg] AVOutputAudioDeviceNext unimplemented"sv); + // av_output_audio_device_next(); + return static_cast(ErrNo::UnImplemented); +} + +Expect AVOutputVideoDeviceNext::body(const Runtime::CallingFrame &) { + spdlog::error("[WasmEdge-FFmpeg] AVOutputVideoDeviceNext unimplemented"sv); + // av_output_video_device_next(); + return static_cast(ErrNo::UnImplemented); +} + +Expect AVDeviceFreeListDevices::body(const Runtime::CallingFrame &, + uint32_t AVDeviceInfoListId) { + FFMPEG_PTR_FETCH(AvDeviceInfoList, AVDeviceInfoListId, AVDeviceInfoList *); + avdevice_free_list_devices(AvDeviceInfoList); + FFMPEG_PTR_DELETE(AVDeviceInfoListId); + return static_cast(ErrNo::Success); +} + +Expect AVDeviceNbDevices::body(const Runtime::CallingFrame &, + uint32_t AVDeviceInfoListId) { + FFMPEG_PTR_FETCH(AvDeviceInfoList, AVDeviceInfoListId, AVDeviceInfoList *); + return (*AvDeviceInfoList)->nb_devices; +} + +Expect AVDeviceDefaultDevice::body(const Runtime::CallingFrame &, + uint32_t AVDeviceInfoListId) { + FFMPEG_PTR_FETCH(AvDeviceInfoList, AVDeviceInfoListId, AVDeviceInfoList *); + return (*AvDeviceInfoList)->default_device; +} + +Expect +AVDeviceConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = avdevice_configuration(); + return strlen(Config); +} + +Expect AVDeviceConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = avdevice_configuration(); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVDeviceLicenseLength::body(const Runtime::CallingFrame &) { + const char *License = avdevice_license(); + return strlen(License); +} + +Expect AVDeviceLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, + uint32_t LicenseLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = avdevice_license(); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace AVDevice +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h new file mode 100644 index 00000000..6b5832f5 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVDevice { + +class AVDeviceRegisterAll : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVDeviceVersion : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVDeviceListDevices : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVFormatCtxId, uint32_t AVDeviceInfoListPtr); +}; + +class AVInputAudioDeviceNext : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &); +}; + +class AVInputVideoDeviceNext : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &); +}; + +class AVOutputAudioDeviceNext : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &); +}; + +class AVOutputVideoDeviceNext : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &); +}; + +class AVDeviceFreeListDevices : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVDeviceInfoListId); +}; + +class AVDeviceNbDevices : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVDeviceInfoListId); +}; + +class AVDeviceDefaultDevice : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVDeviceInfoListId); +}; + +class AVDeviceConfigurationLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVDeviceConfiguration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class AVDeviceLicenseLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVDeviceLicense : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace AVDevice +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avdevice/module.cpp b/plugins/wasmedge_ffmpeg/avdevice/module.cpp new file mode 100644 index 00000000..0e58d788 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avdevice/module.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "avDevice_func.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVDevice { + +WasmEdgeFFmpegAVDeviceModule::WasmEdgeFFmpegAVDeviceModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_avdevice") { + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_register_all", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_list_devices", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_free_list_devices", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_nb_devices", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_default_device", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_license", + std::make_unique(Env)); +} + +} // namespace AVDevice +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avdevice/module.h b/plugins/wasmedge_ffmpeg/avdevice/module.h new file mode 100644 index 00000000..26ed72df --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avdevice/module.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVDevice { + +class WasmEdgeFFmpegAVDeviceModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegAVDeviceModule(std::shared_ptr Env); +}; + +} // namespace AVDevice +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp b/plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp new file mode 100644 index 00000000..9020b759 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avFilter.h" + +extern "C" { +#include "libavfilter/avfilter.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +Expect AVFilterNameLength::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + return strlen(Filter->name); +} + +Expect AVFilterName::body(const Runtime::CallingFrame &Frame, + uint32_t FilterId, uint32_t NamePtr, + uint32_t NameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + const char *Name = Filter->name; + std::copy_n(Name, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterDescriptionLength::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + return strlen(Filter->description); +} + +Expect AVFilterDescription::body(const Runtime::CallingFrame &Frame, + uint32_t FilterId, uint32_t DescPtr, + uint32_t DescLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(DescBuf, MemInst, char, DescPtr, DescLen, ""); + + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + const char *Desc = Filter->description; + std::copy_n(Desc, DescLen, DescBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterNbInputs::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + return Filter->nb_inputs; +} + +Expect AVFilterNbOutputs::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + return Filter->nb_outputs; +} + +Expect AVFilterFlags::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + return Filter->flags; +} + +Expect AVFilterInOutSetName::body(const Runtime::CallingFrame &Frame, + uint32_t InOutId, uint32_t NamePtr, + uint32_t NameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + + FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); + + std::string Name; + std::copy_n(NameBuf.data(), NameLen, std::back_inserter(Name)); + char *CName = av_strdup(Name.c_str()); + if (CName == nullptr) { + return static_cast(ErrNo::Success); + } + InOut->name = CName; + return static_cast(ErrNo::Success); +} + +Expect AVFilterInOutSetFilterCtx::body(const Runtime::CallingFrame &, + uint32_t InOutId, + uint32_t FilterCtxId) { + FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); + FFMPEG_PTR_FETCH(FilterCtx, FilterCtxId, AVFilterContext); + + InOut->filter_ctx = FilterCtx; + return static_cast(ErrNo::Success); +} + +Expect AVFilterInOutSetPadIdx::body(const Runtime::CallingFrame &, + uint32_t InOutId, int32_t PadIdx) { + FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); + InOut->pad_idx = PadIdx; + return static_cast(ErrNo::Success); +} + +Expect AVFilterInOutSetNext::body(const Runtime::CallingFrame &, + uint32_t InOutId, + uint32_t NextInOutId) { + FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); + FFMPEG_PTR_FETCH(NextInOut, NextInOutId, AVFilterInOut); + InOut->next = NextInOut; + return static_cast(ErrNo::Success); +} + +Expect +AVFilterGetInputsFilterPad::body(const Runtime::CallingFrame &Frame, + uint32_t FilterId, uint32_t FilterPadPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FilterPadId, MemInst, uint32_t, FilterPadPtr, "") + + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + const AVFilterPad *FilterPad = Filter->inputs; + if (FilterPad == nullptr) { + return static_cast(ErrNo::Success); + } + FFMPEG_PTR_STORE(const_cast(FilterPad), FilterPadId); + return static_cast(ErrNo::Success); +} + +Expect +AVFilterGetOutputsFilterPad::body(const Runtime::CallingFrame &Frame, + uint32_t FilterId, uint32_t FilterPadPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FilterPadId, MemInst, uint32_t, FilterPadPtr, "") + + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + const AVFilterPad *FilterPad = Filter->outputs; + if (FilterPad == nullptr) { + return static_cast(ErrNo::Success); + } + FFMPEG_PTR_STORE(const_cast(FilterPad), FilterPadId); + return static_cast(ErrNo::Success); +} + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/avFilter.h b/plugins/wasmedge_ffmpeg/avfilter/avFilter.h new file mode 100644 index 00000000..8141d3d7 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/avFilter.h @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +class AVFilterNameLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, + uint32_t NamePtr, uint32_t NameLen); +}; + +class AVFilterDescriptionLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterDescription : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, + uint32_t DescPtr, uint32_t DescLen); +}; + +class AVFilterNbInputs : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterNbOutputs : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterFlags : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterInOutSetName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, + uint32_t NamePtr, uint32_t NameLen); +}; + +class AVFilterInOutSetFilterCtx + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, + uint32_t FilterCtxId); +}; + +class AVFilterInOutSetPadIdx : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, + int32_t PadIdx); +}; + +class AVFilterInOutSetNext : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, + uint32_t NextInOutId); +}; + +class AVFilterGetInputsFilterPad + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, + uint32_t FilterPadPtr); +}; + +class AVFilterGetOutputsFilterPad + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, + uint32_t FilterPadPtr); +}; + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp new file mode 100644 index 00000000..8c46d501 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avfilter_func.h" + +extern "C" { +#include "libavfilter/avfilter.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +Expect AVFilterGraphAlloc::body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FilterGraphId, MemInst, uint32_t, FilterGraphPtr, "") + + FFMPEG_PTR_FETCH(FilterGraph, *FilterGraphId, AVFilterGraph); + + FilterGraph = avfilter_graph_alloc(); + if (FilterGraph == nullptr) { + return static_cast(ErrNo::Success); + } + FFMPEG_PTR_STORE(FilterGraph, FilterGraphId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterGraphConfig::body(const Runtime::CallingFrame &, + uint32_t FilterGraphId) { + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + return avfilter_graph_config(FilterGraph, + nullptr); // log_ctx always NULL on Rust SDK. +} + +Expect AVFilterGraphFree::body(const Runtime::CallingFrame &, + uint32_t FilterGraphId) { + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + avfilter_graph_free(&FilterGraph); + FFMPEG_PTR_DELETE(FilterGraphId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterGraphGetFilter::body(const Runtime::CallingFrame &Frame, + uint32_t FilterCtxPtr, + uint32_t FilterGraphId, + uint32_t NamePtr, + uint32_t NameSize) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(NameId, MemInst, char, NamePtr, + "Failed when accessing the return Name memory"sv); + MEM_PTR_CHECK(FilterCtxId, MemInst, uint32_t, FilterCtxPtr, ""); + + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + FFMPEG_PTR_FETCH(FilterCtx, *FilterCtxId, AVFilterContext); + + std::string Name; + std::copy_n(NameId, NameSize, std::back_inserter(Name)); + + FilterCtx = avfilter_graph_get_filter(FilterGraph, Name.c_str()); + if (FilterCtx == nullptr) { + return static_cast(ErrNo::Success); + } + FFMPEG_PTR_STORE(FilterCtx, FilterCtxId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterGraphParsePtr::body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId, + uint32_t FiltersString, + uint32_t FiltersSize, + uint32_t InputsId, + uint32_t OutputsId) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FiltersId, MemInst, char, FiltersString, ""); + + FFMPEG_PTR_FETCH(Inputs, InputsId, AVFilterInOut); + FFMPEG_PTR_FETCH(Outputs, OutputsId, AVFilterInOut); + FFMPEG_PTR_FETCH(FiltersGraph, FilterGraphId, AVFilterGraph); + + std::string Filters; + std::copy_n(FiltersId, FiltersSize, std::back_inserter(Filters)); + return avfilter_graph_parse_ptr(FiltersGraph, Filters.c_str(), &Inputs, + &Outputs, nullptr); +} + +Expect AVFilterInOutFree::body(const Runtime::CallingFrame &, + uint32_t InOutId) { + FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); + avfilter_inout_free(&InOut); + FFMPEG_PTR_DELETE(InOutId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterVersion::body(const Runtime::CallingFrame &) { + return avfilter_version(); +} + +Expect AVFilterGetByName::body(const Runtime::CallingFrame &Frame, + uint32_t FilterPtr, uint32_t StrPtr, + uint32_t StrLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(StrId, MemInst, char, StrPtr, + "Failed when accessing the return Str memory"sv); + MEM_PTR_CHECK(FilterId, MemInst, uint32_t, FilterPtr, + "Failed when accessing the return Filter memory"sv); + + FFMPEG_PTR_FETCH(Filter, *FilterId, const struct AVFilter); + std::string Name; + std::copy_n(StrId, StrLen, std::back_inserter(Name)); + + Filter = avfilter_get_by_name(Name.c_str()); + if (Filter == nullptr) { + return static_cast(ErrNo::Success); + } + + FFMPEG_PTR_STORE(const_cast(Filter), FilterId); + return static_cast(ErrNo::Success); +} + +Expect +AVFilterConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = avfilter_configuration(); + return strlen(Config); +} + +Expect AVFilterConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = avfilter_configuration(); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterLicenseLength::body(const Runtime::CallingFrame &) { + const char *License = avfilter_license(); + return strlen(License); +} + +Expect AVFilterLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, + uint32_t LicenseLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = avfilter_license(); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterGraphCreateFilter::body( + const Runtime::CallingFrame &Frame, uint32_t FilterCtxPtr, + uint32_t FilterId, uint32_t NamePtr, uint32_t NameLen, uint32_t ArgsPtr, + uint32_t ArgsLen, uint32_t FilterGraphId) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + MEM_SPAN_CHECK(ArgsBuf, MemInst, char, ArgsPtr, ArgsLen, ""); + MEM_PTR_CHECK(FilterCtxId, MemInst, uint32_t, FilterCtxPtr, "") + + FFMPEG_PTR_FETCH(FilterCtx, *FilterCtxId, AVFilterContext); + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + + std::string Name; + std::string Args; + std::copy_n(NameBuf.data(), NameLen, std::back_inserter(Name)); + std::copy_n(ArgsBuf.data(), ArgsLen, std::back_inserter(Args)); + + int Res = avfilter_graph_create_filter(&FilterCtx, Filter, Name.c_str(), + Args.c_str(), nullptr, FilterGraph); + if (Res < 0) { + return Res; + } + + FFMPEG_PTR_STORE(FilterCtx, FilterCtxId); + return Res; +} + +Expect AVFilterInOutAlloc::body(const Runtime::CallingFrame &Frame, + uint32_t InOutPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(InOutId, MemInst, uint32_t, InOutPtr, "") + + FFMPEG_PTR_FETCH(InOut, *InOutId, AVFilterInOut); + InOut = avfilter_inout_alloc(); + if (InOut == nullptr) { + return static_cast(ErrNo::Success); + } + FFMPEG_PTR_STORE(InOut, InOutId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterPadGetNameLength::body(const Runtime::CallingFrame &, + uint32_t FilterPadId, + int32_t Idx) { + FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); + + const char *Name = avfilter_pad_get_name(FilterPad, Idx); + return strlen(Name); +} + +Expect AVFilterPadGetName::body(const Runtime::CallingFrame &Frame, + uint32_t FilterPadId, int32_t Idx, + uint32_t NamePtr, uint32_t NameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + + FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); + + const char *Name = avfilter_pad_get_name(FilterPad, Idx); + auto Actual = std::strlen(Name); + auto N = std::min(NameLen, static_cast(Actual + 1)); + std::copy_n(Name, N, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterPadGetType::body(const Runtime::CallingFrame &, + uint32_t FilterPadId, int32_t Idx) { + FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); + AVMediaType const MediaType = avfilter_pad_get_type(FilterPad, Idx); + return FFmpegUtils::MediaType::fromMediaType(MediaType); +} + +Expect AVFilterGraphDumpLength::body(const Runtime::CallingFrame &, + uint32_t FilterGraphId) { + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + char *Graph = avfilter_graph_dump(FilterGraph, nullptr); + return strlen(Graph); +} + +Expect AVFilterGraphDump::body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId, + uint32_t GraphStrPtr, + uint32_t GraphStrLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(GraphStr, MemInst, char, GraphStrPtr, GraphStrLen, ""); + + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + + char *Graph = avfilter_graph_dump(FilterGraph, nullptr); + std::copy_n(Graph, GraphStrLen, GraphStr.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterFreeGraphStr::body(const Runtime::CallingFrame &, + uint32_t FilterGraphId) { + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + + char *Graph = avfilter_graph_dump(FilterGraph, nullptr); + av_free(Graph); + return static_cast(ErrNo::Success); +} + +Expect AVFilterDrop::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + if (Filter == nullptr) { + return static_cast(ErrNo::Success); + } + FFMPEG_PTR_DELETE(FilterId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterPadDrop::body(const Runtime::CallingFrame &, + uint32_t FilterPadId) { + FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); + if (FilterPad == nullptr) { + return static_cast(ErrNo::Success); + } + FFMPEG_PTR_DELETE(FilterPadId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterContextDrop::body(const Runtime::CallingFrame &, + uint32_t FilterCtxId) { + FFMPEG_PTR_FETCH(FilterCtx, FilterCtxId, AVFilterContext); + if (FilterCtx == nullptr) { + return static_cast(ErrNo::Success); + } + FFMPEG_PTR_DELETE(FilterCtxId); + return static_cast(ErrNo::Success); +} + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h new file mode 100644 index 00000000..6f8b486d --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +class AVFilterGraphAlloc : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphPtr); +}; + +class AVFilterGraphConfig : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId); +}; + +class AVFilterGraphFree : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId); +}; + +class AVFilterGraphGetFilter : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterCtxPtr, uint32_t FilterGraphId, + uint32_t NamePtr, uint32_t NameSize); +}; + +class AVFilterGraphParsePtr : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId, uint32_t FiltersString, + uint32_t FiltersSize, uint32_t InputsId, + uint32_t OutputsId); +}; + +class AVFilterInOutFree : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId); +}; + +class AVFilterVersion : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFilterGetByName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPtr, + uint32_t StrPtr, uint32_t StrLen); +}; + +class AVFilterConfigurationLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFilterConfiguration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class AVFilterLicenseLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFilterLicense : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +class AVFilterGraphCreateFilter + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterCtxPtr, uint32_t FilterId, + uint32_t NamePtr, uint32_t NameLen, uint32_t ArgsPtr, + uint32_t ArgsLen, uint32_t FilterGraphId); +}; + +class AVFilterInOutAlloc : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutPtr); +}; + +class AVFilterPadGetNameLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId, + int32_t Idx); +}; + +class AVFilterPadGetName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId, + int32_t Idx, uint32_t NamePtr, uint32_t NameLen); +}; + +class AVFilterPadGetType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId, + int32_t Idx); +}; + +class AVFilterGraphDumpLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId); +}; + +class AVFilterGraphDump : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId, uint32_t GraphStrPtr, + uint32_t GraphStrLen); +}; + +class AVFilterFreeGraphStr : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId); +}; + +class AVFilterDrop : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterPadDrop : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterPadId); +}; + +class AVFilterContextDrop : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterCtxId); +}; + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp new file mode 100644 index 00000000..5753477d --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "buffer_source_sink.h" + +extern "C" { +#include "libavfilter/buffersink.h" +#include "libavfilter/buffersrc.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +Expect AVBufferSinkGetFrame::body(const Runtime::CallingFrame &, + uint32_t FilterContextId, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + FFMPEG_PTR_FETCH(Frame, FrameId, AVFrame); + return av_buffersink_get_frame(FilterCtx, Frame); +} + +Expect AVBufferSinkGetSamples::body(const Runtime::CallingFrame &, + uint32_t FilterContextId, + uint32_t FrameId, + int32_t Samples) { + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + FFMPEG_PTR_FETCH(Frame, FrameId, AVFrame); + return av_buffersink_get_samples(FilterCtx, Frame, Samples); +} + +Expect AvBufferSinkSetFrameSize::body(const Runtime::CallingFrame &, + uint32_t FilterContextId, + int32_t Value) { + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + av_buffersink_set_frame_size(FilterCtx, Value); + return static_cast(ErrNo::Success); +} + +Expect +AVBufferSrcGetNbFailedRequests::body(const Runtime::CallingFrame &, + uint32_t FilterContextId) { + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + return av_buffersrc_get_nb_failed_requests(FilterCtx); +} + +Expect AVBufferSrcAddFrame::body(const Runtime::CallingFrame &, + uint32_t FilterContextId, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + FFMPEG_PTR_FETCH(Frame, FrameId, AVFrame); + return av_buffersrc_add_frame(FilterCtx, Frame); +} + +Expect AVBufferSrcClose::body(const Runtime::CallingFrame &, + uint32_t FilterContextId, int64_t Pts, + uint32_t Flags) { + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + return av_buffersrc_close(FilterCtx, Pts, Flags); +} + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h new file mode 100644 index 00000000..40db402b --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +class AVBufferSinkGetFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId, uint32_t FrameId); +}; + +class AVBufferSinkGetSamples : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId, uint32_t FrameId, + int32_t Samples); +}; + +class AvBufferSinkSetFrameSize : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId, int32_t Value); +}; + +class AVBufferSrcGetNbFailedRequests + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId); +}; + +class AVBufferSrcAddFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId, uint32_t FrameId); +}; + +class AVBufferSrcClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId, int64_t Pts, uint32_t Flags); +}; + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/module.cpp b/plugins/wasmedge_ffmpeg/avfilter/module.cpp new file mode 100644 index 00000000..761d5fba --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/module.cpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "avFilter.h" +#include "avfilter_func.h" +#include "buffer_source_sink.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +WasmEdgeFFmpegAVFilterModule::WasmEdgeFFmpegAVFilterModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_avfilter") { + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_alloc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_config", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_free", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_get_filter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_parse_ptr", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_free", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_get_by_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_license", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_create_filter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_alloc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_pad_get_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_pad_get_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_pad_get_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_dump_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_dump", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_free_graph_str", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_drop", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_pad_drop", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_context_drop", + std::make_unique(Env)); + + // buffersrc.h && buffersink.h + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersink_get_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersink_get_samples", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersink_set_frame_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersrc_get_nb_failed_requests", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersrc_add_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersrc_close", + std::make_unique(Env)); + + // avfilter.h + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_description_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_description", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_nb_inputs", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_nb_outputs", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_set_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_set_filter_ctx", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_set_pad_idx", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_set_next", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_get_inputs_filter_pad", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_get_outputs_filter_pad", + std::make_unique(Env)); +} + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/module.h b/plugins/wasmedge_ffmpeg/avfilter/module.h new file mode 100644 index 00000000..176704fa --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/module.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +class WasmEdgeFFmpegAVFilterModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegAVFilterModule(std::shared_ptr Env); +}; + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp b/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp new file mode 100644 index 00000000..f0657a52 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp @@ -0,0 +1,200 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avChapter.h" + +extern "C" { +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +Expect AVChapterId::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, uint32_t ChapterIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) { + AvChapter++; + } + + return static_cast(*AvChapter)->id; +} + +Expect AVChapterSetId::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx, int64_t ChapterId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) { + AvChapter++; + } + + (*AvChapter)->id = ChapterId; + return static_cast(ErrNo::Success); +} + +Expect AVChapterTimebase::body(const Runtime::CallingFrame &Frame, + uint32_t NumPtr, uint32_t DenPtr, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) { + AvChapter++; + } + + AVRational const AvRational = static_cast(*AvChapter)->time_base; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect AVChapterSetTimebase::body(const Runtime::CallingFrame &, + int32_t Num, int32_t Den, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVRational const Timebase = av_make_q(Num, Den); + + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) { + AvChapter++; + } + + (*AvChapter)->time_base = Timebase; + return static_cast(ErrNo::Success); +} + +Expect AVChapterStart::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) { + AvChapter++; + } + + return static_cast(*AvChapter)->start; +} + +Expect AVChapterSetStart::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx, + int64_t StartValue) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) { + AvChapter++; + } + + (*AvChapter)->start = StartValue; + return static_cast(ErrNo::Success); +} + +Expect AVChapterEnd::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) { + AvChapter++; + } + + return static_cast(*AvChapter)->end; +} + +Expect AVChapterSetEnd::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx, int64_t EndValue) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) { + AvChapter++; + } + + (*AvChapter)->end = EndValue; + return static_cast(ErrNo::Success); +} + +Expect AVChapterMetadata::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx, uint32_t DictPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, + "Failed when accessing the return AVDictionary memory"sv); + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + + AVDictionary **AvDictionary = + static_cast(av_malloc(sizeof(AVDictionary *))); + AVChapter **AvChapter = AvFormatCtx->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) { + AvChapter++; + } + + *AvDictionary = (*AvChapter)->metadata; + FFMPEG_PTR_STORE(AvDictionary, DictId); + return static_cast(ErrNo::Success); +} + +Expect AVChapterSetMetadata::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx, + uint32_t DictId) { + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); + + AVChapter **AvChapter = AvFormatCtx->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) { + AvChapter++; + } + + if (AvDictionary == nullptr) { + (*AvChapter)->metadata = nullptr; + } else { + (*AvChapter)->metadata = *AvDictionary; + } + return static_cast(ErrNo::Success); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avChapter.h b/plugins/wasmedge_ffmpeg/avformat/avChapter.h new file mode 100644 index 00000000..822c7367 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avChapter.h @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class AVChapterId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx); +}; + +class AVChapterSetId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx, + int64_t ChapterId); +}; + +class AVChapterTimebase : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, + uint32_t DenPtr, uint32_t AvFormatCtxId, + uint32_t ChapterIdx); +}; + +class AVChapterSetTimebase : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t Num, + int32_t Den, uint32_t AvFormatCtxId, + uint32_t ChapterIdx); +}; + +class AVChapterStart : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx); +}; + +class AVChapterSetStart : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx, + int64_t StartValue); +}; + +class AVChapterEnd : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx); +}; + +class AVChapterSetEnd : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx, + int64_t EndValue); +}; + +class AVChapterMetadata : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx, + uint32_t DictPtr); +}; + +class AVChapterSetMetadata : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx, + uint32_t DictId); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp new file mode 100644 index 00000000..826da201 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp @@ -0,0 +1,210 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avInputOutputFormat.h" + +extern "C" { +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +Expect AVIOFormatNameLength::body(const Runtime::CallingFrame &, + uint32_t AVIOFormatId, + uint32_t FormatType) { + const char *Name; + + if (FormatType == 0) { + FFMPEG_PTR_FETCH(AvInputFormat, AVIOFormatId, AVInputFormat); + Name = AvInputFormat->name; + } else { + FFMPEG_PTR_FETCH(AvOutputFormat, AVIOFormatId, AVOutputFormat); + Name = AvOutputFormat->name; + } + + if (Name == nullptr) { + return 0; + } + return strlen(Name); +} + +Expect AVInputFormatName::body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, + uint32_t NamePtr, uint32_t NameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); + + const char *Name = AvInputFormat->name; + std::copy_n(Name, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVOutputFormatName::body(const Runtime::CallingFrame &Frame, + uint32_t AVOutputFormatId, + uint32_t NamePtr, uint32_t NameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + + const char *Name = AvOutputFormat->name; + std::copy_n(Name, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVIOFormatLongNameLength::body(const Runtime::CallingFrame &, + uint32_t AVIOFormatId, + uint32_t FormatType) { + const char *LongName; + + if (FormatType == 0) { + FFMPEG_PTR_FETCH(AvInputFormat, AVIOFormatId, AVInputFormat); + LongName = AvInputFormat->long_name; + } else { + FFMPEG_PTR_FETCH(AvOutputFormat, AVIOFormatId, AVOutputFormat); + LongName = AvOutputFormat->long_name; + } + + if (LongName == nullptr) { + return 0; + } + return strlen(LongName); +} + +Expect AVInputFormatLongName::body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, + uint32_t LongNamePtr, + uint32_t LongNameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LongNameBuf, MemInst, char, LongNamePtr, LongNameLen, ""); + FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); + + const char *LongName = AvInputFormat->long_name; + std::copy_n(LongName, LongNameLen, LongNameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVOutputFormatLongName::body(const Runtime::CallingFrame &Frame, + uint32_t AVOutputFormatId, + uint32_t LongNamePtr, + uint32_t LongNameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LongNameBuf, MemInst, char, LongNamePtr, LongNameLen, ""); + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + + const char *LongName = AvOutputFormat->long_name; + std::copy_n(LongName, LongNameLen, LongNameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVIOFormatExtensionsLength::body(const Runtime::CallingFrame &, + uint32_t AVIOFormatId, + uint32_t FormatType) { + const char *Extensions; + + if (FormatType == 0) { + FFMPEG_PTR_FETCH(AvInputFormat, AVIOFormatId, AVInputFormat); + Extensions = AvInputFormat->extensions; + } else { + FFMPEG_PTR_FETCH(AvOutputFormat, AVIOFormatId, AVOutputFormat); + Extensions = AvOutputFormat->extensions; + } + + if (Extensions == nullptr) { + return 0; + } + return strlen(Extensions); +} + +Expect +AVInputFormatExtensions::body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t ExtensionsPtr, + uint32_t ExtensionsLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ExtensionsBuf, MemInst, char, ExtensionsPtr, ExtensionsLen, + ""); + FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); + + const char *Extensions = AvInputFormat->extensions; + std::copy_n(Extensions, ExtensionsLen, ExtensionsBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect +AVOutputFormatExtensions::body(const Runtime::CallingFrame &Frame, + uint32_t AVOutputFormatId, + uint32_t ExtensionsPtr, uint32_t ExtensionsLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ExtensionsBuf, MemInst, char, ExtensionsPtr, ExtensionsLen, + ""); + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + + const char *Extensions = AvOutputFormat->extensions; + std::copy_n(Extensions, ExtensionsLen, ExtensionsBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVIOFormatMimeTypeLength::body(const Runtime::CallingFrame &, + uint32_t AVIOFormatId, + uint32_t FormatType) { + const char *MimeType; + + if (FormatType == 0) { + FFMPEG_PTR_FETCH(AvInputFormat, AVIOFormatId, AVInputFormat); + MimeType = AvInputFormat->mime_type; + } else { + FFMPEG_PTR_FETCH(AvOutputFormat, AVIOFormatId, AVOutputFormat); + MimeType = AvOutputFormat->mime_type; + } + + if (MimeType == nullptr) { + return 0; + } + return strlen(MimeType); +} + +Expect AVInputFormatMimeType::body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, + uint32_t MimeTypePtr, + uint32_t MimeTypeLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(MimeTypeBuf, MemInst, char, MimeTypePtr, MimeTypeLen, ""); + FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); + + const char *MimeType = AvInputFormat->mime_type; + std::copy_n(MimeType, MimeTypeLen, MimeTypeBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVOutputFormatMimeType::body(const Runtime::CallingFrame &Frame, + uint32_t AVOutputFormatId, + uint32_t MimeTypePtr, + uint32_t MimeTypeLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(MimeTypeBuf, MemInst, char, MimeTypePtr, MimeTypeLen, ""); + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + + const char *MimeType = AvOutputFormat->mime_type; + std::copy_n(MimeType, MimeTypeLen, MimeTypeBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVOutputFormatFlags::body(const Runtime::CallingFrame &, + uint32_t AVOutputFormatId) { + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + return AvOutputFormat->flags; +} + +Expect AVInputOutputFormatFree::body(const Runtime::CallingFrame &, + uint32_t AVInputOutputId) { + FFMPEG_PTR_DELETE(AVInputOutputId); + return static_cast(ErrNo::Success); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h new file mode 100644 index 00000000..8f01ba9d --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class AVIOFormatNameLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, + uint32_t FormatType); +}; + +class AVInputFormatName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t NamePtr, + uint32_t NameLen); +}; + +class AVOutputFormatName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t NamePtr, + uint32_t NameLen); +}; + +class AVIOFormatLongNameLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, + uint32_t FormatType); +}; + +class AVInputFormatLongName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t LongNamePtr, + uint32_t LongNameLen); +}; + +class AVOutputFormatLongName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t LongNamePtr, + uint32_t LongNameLen); +}; + +class AVIOFormatExtensionsLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, + uint32_t FormatType); +}; + +class AVInputFormatExtensions : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t Extensions, + uint32_t ExtensionsLen); +}; + +class AVOutputFormatExtensions : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t Extensions, + uint32_t ExtensionsLen); +}; + +class AVIOFormatMimeTypeLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, + uint32_t FormatType); +}; + +class AVInputFormatMimeType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t MimeTypePtr, + uint32_t MimeTypeLen); +}; + +class AVOutputFormatMimeType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t MimeTypePtr, + uint32_t MimeTypeLen); +}; + +class AVOutputFormatFlags : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId); +}; + +class AVInputOutputFormatFree : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputOutputId); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avStream.cpp b/plugins/wasmedge_ffmpeg/avformat/avStream.cpp new file mode 100644 index 00000000..6f1c60ce --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avStream.cpp @@ -0,0 +1,290 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avStream.h" + +extern "C" { +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +Expect AVStreamId::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + return static_cast(*AvStream)->id; +} + +Expect AVStreamIndex::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + return static_cast(*AvStream)->index; +} + +Expect AVStreamCodecPar::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t StreamIdx, + uint32_t CodecParameterPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(CodecParamId, MemInst, uint32_t, CodecParameterPtr, + "Failed when accessing the return CodecParameter Memory"sv); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + AVCodecParameters *CodecParam = + (static_cast(*AvStream))->codecpar; + FFMPEG_PTR_STORE(CodecParam, CodecParamId); + return static_cast(ErrNo::Success); +} + +Expect AVStreamTimebase::body(const Runtime::CallingFrame &Frame, + uint32_t NumPtr, uint32_t DenPtr, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + AVRational const AvRational = static_cast(*AvStream)->time_base; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect AVStreamSetTimebase::body(const Runtime::CallingFrame &, + uint32_t Num, uint32_t Den, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + AVStream **AvStream = AvFormatContext->streams; + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + AVRational const Timebase = av_make_q(Num, Den); + (*AvStream)->time_base = Timebase; + return static_cast(ErrNo::Success); +} + +Expect AVStreamDuration::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + return static_cast(*AvStream)->duration; +} + +Expect AVStreamStartTime::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + return static_cast(*AvStream)->start_time; +} + +Expect AVStreamNbFrames::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + return static_cast(*AvStream)->nb_frames; +} + +Expect AVStreamDisposition::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + return static_cast(*AvStream)->disposition; +} + +Expect AVStreamRFrameRate::body(const Runtime::CallingFrame &Frame, + uint32_t NumPtr, uint32_t DenPtr, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + AVRational const AvRational = + static_cast(*AvStream)->r_frame_rate; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect AVStreamSetRFrameRate::body(const Runtime::CallingFrame &, + int32_t Num, int32_t Den, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + AVStream **AvStream = AvFormatContext->streams; + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + AVRational const RFrameRate = av_make_q(Num, Den); + (*AvStream)->r_frame_rate = RFrameRate; + return static_cast(ErrNo::Success); +} + +Expect AVStreamAvgFrameRate::body(const Runtime::CallingFrame &Frame, + uint32_t NumPtr, uint32_t DenPtr, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + AVRational const AvRational = + static_cast(*AvStream)->avg_frame_rate; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect AVStreamSetAvgFrameRate::body(const Runtime::CallingFrame &, + int32_t Num, int32_t Den, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + AVRational const AvgFrameRate = av_make_q(Num, Den); + (*AvStream)->avg_frame_rate = AvgFrameRate; + return static_cast(ErrNo::Success); +} + +Expect AVStreamMetadata::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t StreamIdx, uint32_t DictPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, + "Failed when accessing the return AVDictPtr Memory"sv); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + AVDictionary **AvDictionary = + static_cast(av_malloc(sizeof(AVDictionary *))); + + *AvDictionary = (*AvStream)->metadata; + FFMPEG_PTR_STORE(AvDictionary, DictId); + return static_cast(ErrNo::Success); +} + +Expect AVStreamSetMetadata::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx, uint32_t DictId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); + + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + if (AvDictionary == nullptr) { + (*AvStream)->metadata = nullptr; + } else { + (*AvStream)->metadata = *AvDictionary; + } + + return static_cast(ErrNo::Success); +} + +Expect AVStreamDiscard::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) { + AvStream++; + } + + return static_cast((*AvStream)->discard); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avStream.h b/plugins/wasmedge_ffmpeg/avformat/avStream.h new file mode 100644 index 00000000..282d23b0 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avStream.h @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class AVStreamId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamIndex : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamCodecPar : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx, + uint32_t CodecParameterPtr); +}; + +class AVStreamTimebase : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, + uint32_t DenPtr, uint32_t AvFormatCtxId, + uint32_t StreamIdx); +}; + +class AVStreamSetTimebase : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t Num, + uint32_t Den, uint32_t AvFormatCtxId, + uint32_t StreamIdx); +}; + +class AVStreamDuration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamStartTime : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamNbFrames : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamDisposition : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamRFrameRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, + uint32_t DenPtr, uint32_t AvFormatCtxId, + uint32_t StreamIdx); +}; + +class AVStreamSetRFrameRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t Num, + int32_t Den, uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamAvgFrameRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, + uint32_t DenPtr, uint32_t AvFormatCtxId, + uint32_t StreamIdx); +}; + +class AVStreamSetAvgFrameRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t Num, + int32_t Den, uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamMetadata : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx, + uint32_t DictPtr); +}; + +class AVStreamSetMetadata : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx, + uint32_t DictId); +}; + +class AVStreamDiscard : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp b/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp new file mode 100644 index 00000000..6ab0e259 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avformatContext.h" + +extern "C" { +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +Expect AVFormatCtxIFormat::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t AvInputFormatPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvInputFormatId, MemInst, uint32_t, AvInputFormatPtr, + "Failed when accessing the return AVInputFormat Memory"sv); + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + + AVInputFormat const *AvInputFormat = AvFormatCtx->iformat; + FFMPEG_PTR_STORE(const_cast(AvInputFormat), AvInputFormatId); + return static_cast(ErrNo::Success); +} + +Expect AVFormatCtxOFormat::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t AvOutputFormatPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvOutputFormatId, MemInst, uint32_t, AvOutputFormatPtr, + "Failed when accessing the return AVOutputFormat Memory"sv); + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + + AVOutputFormat const *AvOutputFormat = AvFormatCtx->oformat; + FFMPEG_PTR_STORE(const_cast(AvOutputFormat), + AvOutputFormatId); + return static_cast(ErrNo::Success); +} + +Expect AVFormatCtxProbeScore::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return AvFormatContext->probe_score; +} + +Expect AVFormatCtxNbStreams::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return AvFormatContext->nb_streams; +}; + +Expect AVFormatCtxBitRate::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return AvFormatContext->bit_rate; +} + +Expect AVFormatCtxDuration::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return AvFormatContext->duration; +} + +Expect AVFormatCtxNbChapters::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return AvFormatContext->nb_chapters; +} + +Expect AVFormatCtxSetNbChapters::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t NbChapters) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AvFormatContext->nb_chapters = NbChapters; + return static_cast(ErrNo::Success); +} + +Expect AVFormatCtxMetadata::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t DictPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, + "Failed when accessing the return AVDictionary memory"sv); + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + + AVDictionary **AvDictionary = + static_cast(av_malloc(sizeof(AVDictionary *))); + + *AvDictionary = AvFormatCtx->metadata; + FFMPEG_PTR_STORE(AvDictionary, DictId); + return static_cast(ErrNo::Success); +} + +Expect AVFormatCtxSetMetadata::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t DictId) { + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); + + if (AvDictionary == nullptr) { + AvFormatCtx->metadata = nullptr; + } else { + AvFormatCtx->metadata = *AvDictionary; + } + return static_cast(ErrNo::Success); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avformatContext.h b/plugins/wasmedge_ffmpeg/avformat/avformatContext.h new file mode 100644 index 00000000..a0d3b1b2 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avformatContext.h @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class AVFormatCtxIFormat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t AvInputFormatPtr); +}; + +class AVFormatCtxOFormat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t AvOutputFormatPtr); +}; + +class AVFormatCtxProbeScore : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatCtxNbStreams : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatCtxBitRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatCtxDuration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatCtxNbChapters : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatCtxSetNbChapters : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t NbChapters); +}; + +class AVFormatCtxMetadata : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t DictPtr); +}; + +class AVFormatCtxSetMetadata : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t DictId); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp b/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp new file mode 100644 index 00000000..331235d1 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avformat_func.h" + +extern "C" { +#include "libavcodec/packet.h" +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +Expect AVFormatOpenInput::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxPtr, + uint32_t UrlPtr, uint32_t UrlSize, + uint32_t AvInputFormatId, + uint32_t AvDictionaryId) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(urlId, MemInst, char, UrlPtr, + "Failed when accessing the return URL memory"sv); + MEM_PTR_CHECK(AvFormatCtxId, MemInst, uint32_t, AvFormatCtxPtr, + "Failed when accessing the return AVFormatContext Memory"sv); + + std::string TargetUrl; + std::copy_n(urlId, UrlSize, std::back_inserter(TargetUrl)); + + AVFormatContext *AvFormatContext = nullptr; + FFMPEG_PTR_FETCH(AvDictionary, AvDictionaryId, AVDictionary *); + FFMPEG_PTR_FETCH(AvInputFormat, AvInputFormatId, AVInputFormat); + + int const Res = avformat_open_input(&AvFormatContext, TargetUrl.c_str(), + AvInputFormat, AvDictionary); + FFMPEG_PTR_STORE(AvFormatContext, AvFormatCtxId); + return Res; +} + +Expect AVFormatFindStreamInfo::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t AvDictionaryId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDictionary, AvDictionaryId, AVDictionary *); + return avformat_find_stream_info(AvFormatContext, AvDictionary); +} + +Expect AVFormatCloseInput::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + avformat_close_input(&AvFormatCtx); + FFMPEG_PTR_DELETE(AvFormatCtxId); + return static_cast(ErrNo::Success); +} + +Expect AVReadPause::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return av_read_pause(AvFormatContext); +} + +Expect AVReadPlay::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return av_read_play(AvFormatContext); +} + +Expect AVFormatSeekFile::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx, int64_t MinTs, + int64_t Ts, int64_t MaxTs, + int32_t Flags) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return avformat_seek_file(AvFormatContext, StreamIdx, MinTs, Ts, MaxTs, + Flags); +} + +Expect AVDumpFormat::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, int32_t Idx, + uint32_t UrlPtr, uint32_t UrlSize, + int32_t IsOutput) { + std::string TargetUrl; + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(UrlBuf, MemInst, char, UrlPtr, ""); + + std::copy_n(UrlBuf, UrlSize, std::back_inserter(TargetUrl)); + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + + av_dump_format(AvFormatCtx, Idx, TargetUrl.c_str(), IsOutput); + return static_cast(ErrNo::Success); +} + +Expect AVFormatFreeContext::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + avformat_free_context(AvFormatCtx); + FFMPEG_PTR_DELETE(AvFormatCtxId); + return static_cast(ErrNo::Success); +} + +Expect AVFindBestStream::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + int32_t MediaTypeId, + int32_t WantedStream, + int32_t RelatedStream, + uint32_t DecoderRetId, int32_t Flags) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(DecoderRet, DecoderRetId, const AVCodec *); + + AVMediaType const AvMediaType = + FFmpegUtils::MediaType::intoMediaType(MediaTypeId); + return av_find_best_stream(AvFormatContext, AvMediaType, WantedStream, + RelatedStream, DecoderRet, Flags); +} + +Expect AVReadFrame::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, uint32_t PacketId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvPacket, PacketId, AVPacket); + + return av_read_frame(AvFormatContext, AvPacket); +} + +Expect AVIOClose::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + avio_close(AvFormatCtx->pb); + return static_cast(ErrNo::Success); +} + +Expect AVFormatNetworkInit::body(const Runtime::CallingFrame &) { + return avformat_network_init(); +} + +Expect AVFormatNetworkDeInit::body(const Runtime::CallingFrame &) { + return avformat_network_deinit(); +} + +Expect AVFormatWriteHeader::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t DictId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); + return avformat_write_header(AvFormatContext, AvDict); +} + +Expect AVFormatWriteTrailer::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return av_write_trailer(AvFormatContext); +} + +Expect AVFormatAllocOutputContext2::body( + const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxPtr, + uint32_t AVOutputFormatId, uint32_t FormatNamePtr, uint32_t FormatLen, + uint32_t FileNamePtr, uint32_t FileNameLen) { + std::string Format; + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FileId, MemInst, char, FileNamePtr, + "Failed when accessing the return FileName memory"sv); + if (FormatLen > 0) { + MEM_PTR_CHECK(FormatId, MemInst, char, FormatNamePtr, + "Failed when accessing the return FormatName memory"sv); + + std::copy_n(FormatId, FormatLen, std::back_inserter(Format)); + } + MEM_PTR_CHECK(AvFormatCtxId, MemInst, uint32_t, AvFormatCtxPtr, + "Failed when accessing the return AVFormatContext Memory"sv); + + std::string File; + std::copy_n(FileId, FileNameLen, std::back_inserter(File)); + + AVFormatContext *AvFormatContext = nullptr; + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + + int Res = 0; + if (FormatLen == 0) { + Res = avformat_alloc_output_context2(&AvFormatContext, AvOutputFormat, + nullptr, File.c_str()); + } else { + Res = avformat_alloc_output_context2(&AvFormatContext, AvOutputFormat, + Format.c_str(), File.c_str()); + } + FFMPEG_PTR_STORE(AvFormatContext, AvFormatCtxId); + return Res; +} + +Expect AVIOOpen::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t FileNamePtr, + uint32_t FileNameLen, int32_t Flags) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FileId, MemInst, char, FileNamePtr, + "Failed when accessing the return FileName memory"sv); + + std::string File; + std::copy_n(FileId, FileNameLen, std::back_inserter(File)); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + return avio_open(&(AvFormatContext->pb), File.c_str(), Flags); +} + +Expect AVIOOpen2::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxtId, uint32_t UrlPtr, + uint32_t UrlLen, int32_t Flags, + uint32_t AVIOInterruptCBId, + uint32_t AVDictionaryId) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(UrlId, MemInst, char, UrlPtr, + "Failed when accessing the return Url memory"sv); + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxtId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDictionary, AVDictionaryId, AVDictionary *); + FFMPEG_PTR_FETCH(AvIOInterruptCB, AVIOInterruptCBId, AVIOInterruptCB); + + std::string TargetUrl; + std::copy_n(UrlId, UrlLen, std::back_inserter(TargetUrl)); + + return avio_open2(&(AvFormatCtx->pb), TargetUrl.c_str(), Flags, + AvIOInterruptCB, AvDictionary); +} + +Expect AVFormatVersion::body(const Runtime::CallingFrame &) { + return avformat_version(); +} + +Expect AVChapterMallocz::body(const Runtime::CallingFrame &Frame, + uint32_t AVChapterPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvChapterId, MemInst, uint32_t, AVChapterPtr, + "Failed to access Memory for AVChapterPtr"sv) + + AVChapter *AvChapter = + static_cast(av_mallocz(sizeof(AVChapter))); + FFMPEG_PTR_STORE(AvChapter, AvChapterId); + return static_cast(ErrNo::Success); +} + +Expect AVChapterDynarrayAdd::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + int32_t NbChaptersPtr, + uint32_t AvChapterId) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(NbChapters, MemInst, int32_t, NbChaptersPtr, + "Failed to access Memory for NbChaptersPtr"sv) + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvChapter, AvChapterId, AVChapter); + + av_dynarray_add(&(AvFormatContext->chapters), NbChapters, AvChapter); + if (*(AvFormatContext->chapters) == nullptr && *(NbChapters) == 0) { + return static_cast(ErrNo::InternalError); + } + return static_cast(ErrNo::Success); +} + +Expect AVFreeP::body(const Runtime::CallingFrame &, + uint32_t AvChapterId) { + FFMPEG_PTR_FETCH(AvChapter, AvChapterId, AVChapter); + av_freep(AvChapter); + FFMPEG_PTR_DELETE(AvChapterId); + return static_cast(ErrNo::Success); +} + +Expect AVInterleavedWriteFrame::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + + return av_interleaved_write_frame(AvFormatCtx, AvPacket); +} + +Expect AVWriteFrame::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t AvPacketId) { + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + + return av_write_frame(AvFormatCtx, AvPacket); +} + +Expect AVFormatNewStream::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t AvCodecId) { + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + AVStream *Stream = avformat_new_stream(AvFormatCtx, AvCodec); + if (Stream == nullptr) { + return 0; + } + return 1; +} + +Expect AVGuessCodec::body(const Runtime::CallingFrame &Frame, + uint32_t AVIOFormatId, + uint32_t ShortNamePtr, + uint32_t ShortNameLen, uint32_t FileNamePtr, + uint32_t FileNameLen, uint32_t MimeTypePtr, + uint32_t MimeTypeLen, int32_t MediaTypeId) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(ShortNameBuf, MemInst, char, ShortNamePtr, + "Failed when accessing the return ShortName memory"sv); + MEM_PTR_CHECK(FileNameBuf, MemInst, char, FileNamePtr, + "Failed when accessing the return FileName memory"sv); + MEM_PTR_CHECK(MimeTypeBuf, MemInst, char, MimeTypePtr, + "Failed when accessing the return MimeType memory"sv); + FFMPEG_PTR_FETCH(AvOutputFormat, AVIOFormatId, AVOutputFormat); + + std::string ShortName; + std::string FileName; + std::string MimeType; + std::copy_n(ShortNameBuf, ShortNameLen, std::back_inserter(ShortName)); + std::copy_n(FileNameBuf, FileNameLen, std::back_inserter(FileName)); + std::copy_n(MimeTypeBuf, MimeTypeLen, std::back_inserter(MimeType)); + + AVMediaType const MediaType = + FFmpegUtils::MediaType::intoMediaType(MediaTypeId); + AVCodecID const Id = + av_guess_codec(AvOutputFormat, ShortName.c_str(), FileName.c_str(), + MimeType.c_str(), MediaType); + + return FFmpegUtils::CodecID::fromAVCodecID(Id); +} + +Expect +AVFormatConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = avformat_configuration(); + return strlen(Config); +} + +Expect AVFormatConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = avformat_configuration(); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFormatLicenseLength::body(const Runtime::CallingFrame &) { + const char *License = avformat_license(); + return strlen(License); +} + +Expect AVFormatLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, + uint32_t LicenseLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = avformat_license(); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_func.h b/plugins/wasmedge_ffmpeg/avformat/avformat_func.h new file mode 100644 index 00000000..533017b6 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avformat_func.h @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class AVFormatOpenInput : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxPtr, uint32_t UrlPtr, + uint32_t UrlSize, uint32_t AvInputFormatId, + uint32_t AvDictionaryId); +}; + +class AVFormatFindStreamInfo : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + uint32_t AvDictionaryId); +}; + +class AVFormatCloseInput : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVReadPause : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId); +}; + +class AVReadPlay : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId); +}; + +class AVFormatSeekFile : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + uint32_t StreamIdx, int64_t MinTs, int64_t Ts, + int64_t MaxTs, int32_t Flags); +}; + +class AVDumpFormat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, int32_t Idx, uint32_t UrlPtr, + uint32_t UrlSize, int32_t IsOutput); +}; + +class AVFormatFreeContext : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxPtr); +}; + +class AVFindBestStream : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + int32_t MediaTypeId, int32_t WantedStream, + int32_t RelatedStream, uint32_t DecoderRetId, + int32_t Flags); +}; + +class AVReadFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + uint32_t PacketId); +}; + +class AVIOClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatNetworkInit : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFormatNetworkDeInit : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFormatWriteHeader : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t DictId); +}; + +class AVFormatWriteTrailer : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatAllocOutputContext2 + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxPtr, uint32_t AVOutputFormatId, + uint32_t FormatNamePtr, uint32_t FormatLen, + uint32_t FileNamePtr, uint32_t FileNameLen); +}; + +class AVIOOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t FileNamePtr, + uint32_t FileNameLen, int32_t Flags); +}; + +class AVIOOpen2 : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxtId, uint32_t UrlPtr, + uint32_t UrlLen, int32_t Flags, + uint32_t AVIOInterruptCBId, uint32_t AVDictionaryId); +}; + +class AVFormatVersion : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVChapterMallocz : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVChapterPtr); +}; + +class AVChapterDynarrayAdd : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + int32_t NbChaptersPtr, uint32_t AvChapterId); +}; + +class AVFreeP : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AvChapterId); +}; + +class AVInterleavedWriteFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + uint32_t AvPacketId); +}; + +class AVWriteFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + uint32_t AvPacketId); +}; + +class AVFormatNewStream : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVFormatCtxId, uint32_t AVCodecId); +}; + +class AVGuessCodec : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputOutputId, uint32_t ShortNamePtr, + uint32_t ShortNameLen, uint32_t FileNamePtr, + uint32_t FileNameLen, uint32_t MimeTypePtr, + uint32_t MimeTypeLen, int32_t MediaTypeId); +}; + +class AVFormatConfigurationLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFormatConfiguration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class AVFormatLicenseLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFormatLicense : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/module.cpp b/plugins/wasmedge_ffmpeg/avformat/module.cpp new file mode 100644 index 00000000..1246beff --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/module.cpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "avChapter.h" +#include "avInputOutputFormat.h" +#include "avStream.h" +#include "avformatContext.h" +#include "avformat_func.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +WasmEdgeFFmpegAVFormatModule::WasmEdgeFFmpegAVFormatModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_avformat") { + // avformat_func.h + addHostFunc("wasmedge_ffmpeg_avformat_avformat_open_input", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_find_stream_info", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_close_input", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_read_play", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_read_pause", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_dump_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_seek_file", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_free_context", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_find_best_stream", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_read_frame", // TODO: Write Test + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avio_close", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_network_init", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_network_deinit", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_write_header", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_write_trailer", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_alloc_output_context2", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avio_open", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avio_open2", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avchapter_mallocz", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avformat_avchapter_dynarray_add", // TODO: Write Test + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_avfreep", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_write_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_interleaved_write_frame", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avformat_avformat_new_stream", // TODO: Write Test + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_guess_codec", // TODO: Write Test + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_license", + std::make_unique(Env)); + + // avformatContext Struct functions. + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_iformat", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_oformat", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_probescope", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_nb_streams", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_duration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_bit_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_nb_chapters", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_set_nb_chapters", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_set_metadata", + std::make_unique(Env)); + + // avInputFormat Struct functions. + addHostFunc("wasmedge_ffmpeg_avformat_avIOFormat_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avInputFormat_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avOutputFormat_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avIOFormat_long_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avInputFormat_long_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avOutputFormat_long_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avIOFormat_extensions_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avInputFormat_extensions", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avOutputFormat_extensions", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avIOFormat_mime_type_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avInputFormat_mime_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avOutputFormat_mime_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avOutputFormat_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avInputOutputFormat_free", + std::make_unique(Env)); + + // avStream Struct Functions. + addHostFunc("wasmedge_ffmpeg_avformat_avStream_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_index", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_codecpar", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_timebase", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_set_timebase", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_duration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_start_time", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_nb_frames", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_disposition", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_r_frame_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_set_r_frame_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_avg_frame_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_set_avg_frame_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_set_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_discard", + std::make_unique(Env)); + + // avChapter Struct Functions. + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_set_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_timebase", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_set_timebase", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_start", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_set_start", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_end", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_set_end", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_set_metadata", + std::make_unique(Env)); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/module.h b/plugins/wasmedge_ffmpeg/avformat/module.h new file mode 100644 index 00000000..47b604d9 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/module.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class WasmEdgeFFmpegAVFormatModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegAVFormatModule(std::shared_ptr Env); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp b/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp new file mode 100644 index 00000000..486916c2 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp @@ -0,0 +1,163 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avDictionary.h" + +extern "C" { +#include "libavutil/dict.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVDictSet::body(const Runtime::CallingFrame &Frame, + uint32_t DictPtr, uint32_t KeyPtr, + uint32_t KeyLen, uint32_t ValuePtr, + uint32_t ValueLen, int32_t Flags) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(KeyBuf, MemInst, char, KeyPtr, + "Failed when accessing the return Key memory"sv); + MEM_PTR_CHECK(ValueBuf, MemInst, char, ValuePtr, + "Failed when accessing the return Value memory"sv); + MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, + "Failed to access Memory for AVDict"sv) + + std::string Key; + std::string Value; + std::copy_n(KeyBuf, KeyLen, std::back_inserter(Key)); + std::copy_n(ValueBuf, ValueLen, std::back_inserter(Value)); + + int Res = 0; + + // Using Maybe::uninit(); in Rust. If Uninitialized, zero is + // passed. Else the Ptr contains a Number. + if (*DictId) { + FFMPEG_PTR_FETCH(AvDict, *DictId, AVDictionary *); + Res = av_dict_set(AvDict, Key.c_str(), Value.c_str(), Flags); + } else { + AVDictionary **AvDict = + static_cast(av_mallocz(sizeof(AVDictionary *))); + Res = av_dict_set(AvDict, Key.c_str(), Value.c_str(), Flags); + FFMPEG_PTR_STORE(AvDict, DictId); + } + + return Res; +} + +Expect AVDictCopy::body(const Runtime::CallingFrame &Frame, + uint32_t DestDictPtr, uint32_t SrcDictId, + uint32_t Flags) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(DestDictId, MemInst, uint32_t, DestDictPtr, + "Failed to access Memory for AVDict"sv) + + FFMPEG_PTR_FETCH(SrcAvDict, SrcDictId, AVDictionary *); + + int Res = 0; + + if (SrcAvDict == nullptr) { + return static_cast(ErrNo::InternalError); + } + + if (*DestDictId) { + FFMPEG_PTR_FETCH(DestAvDict, *DestDictId, AVDictionary *); + Res = av_dict_copy(DestAvDict, *SrcAvDict, Flags); + } else { + AVDictionary **DestAvDict = + static_cast(av_mallocz(sizeof(AVDictionary *))); + av_dict_copy(DestAvDict, *SrcAvDict, Flags); + FFMPEG_PTR_STORE(DestAvDict, DestDictId); + } + + return Res; +} + +Expect AVDictGet::body(const Runtime::CallingFrame &Frame, + uint32_t DictId, uint32_t KeyPtr, + uint32_t KeyLen, uint32_t PrevDictEntryIdx, + uint32_t Flags, uint32_t KeyLenPtr, + uint32_t ValueLenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(KeyStr, MemInst, char, KeyPtr, + "Failed when accessing the return Key memory"sv); + MEM_PTR_CHECK(KeyLenId, MemInst, uint32_t, KeyLenPtr, + "Failed when accessing the return KeyLen memory"sv); + MEM_PTR_CHECK(ValueLenId, MemInst, uint32_t, ValueLenPtr, + "Failed when accessing the return ValueLen memory"sv); + + FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); + + // Return if Dict was not created (i.e. 0 is passed as AVDictId). + if (AvDict == nullptr) { + return static_cast(ErrNo::InternalError); + } + std::string Key; + std::copy_n(KeyStr, KeyLen, std::back_inserter(Key)); + + AVDictionaryEntry *DictEntry = nullptr; + uint32_t Curr = 0; + while (Curr <= PrevDictEntryIdx) { + DictEntry = av_dict_get(*AvDict, Key.c_str(), DictEntry, Flags); + Curr++; + } + + if (DictEntry == nullptr) { + return static_cast(ErrNo::InternalError); + } + + *KeyLenId = strlen(DictEntry->key); + *ValueLenId = strlen(DictEntry->value); + return Curr; +} + +Expect AVDictGetKeyValue::body( + const Runtime::CallingFrame &Frame, uint32_t DictId, uint32_t KeyPtr, + uint32_t KeyLen, uint32_t ValBufPtr, uint32_t ValBufLen, uint32_t KeyBufPtr, + uint32_t KeyBufLen, uint32_t PrevDictEntryIdx, uint32_t Flags) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(KeyStr, MemInst, char, KeyPtr, + "Failed when accessing the return Key memory"sv); + MEM_SPAN_CHECK(KeyBuf, MemInst, char, KeyBufPtr, KeyBufLen, ""); + MEM_SPAN_CHECK(ValBuf, MemInst, char, ValBufPtr, ValBufLen, ""); + + FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); + + // Return if Dict was not created (i.e. 0 is passed as AVDictId). + if (AvDict == nullptr) { + return static_cast(ErrNo::InternalError); + } + + std::string Key; + std::copy_n(KeyStr, KeyLen, std::back_inserter(Key)); + + AVDictionaryEntry *DictEntry = nullptr; + uint32_t Curr = 0; + while (Curr <= PrevDictEntryIdx) { + DictEntry = av_dict_get(*AvDict, Key.c_str(), DictEntry, Flags); + Curr++; + } + if (DictEntry == nullptr) { + return static_cast(ErrNo::InternalError); + } + std::copy_n(DictEntry->value, strlen(DictEntry->value), ValBuf.data()); + std::copy_n(DictEntry->key, strlen(DictEntry->key), KeyBuf.data()); + return Curr; +} + +Expect AVDictFree::body(const Runtime::CallingFrame &, + uint32_t DictId) { + if (DictId == 0) { + return static_cast(ErrNo::Success); + } + FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); + av_dict_free(AvDict); + FFMPEG_PTR_DELETE(DictId); + return static_cast(ErrNo::Success); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avDictionary.h b/plugins/wasmedge_ffmpeg/avutil/avDictionary.h new file mode 100644 index 00000000..c1731a33 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avDictionary.h @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVDictSet : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId, + uint32_t KeyPtr, uint32_t KeyLen, uint32_t ValuePtr, + uint32_t ValueLen, int32_t Flags); +}; + +class AVDictGet : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId, + uint32_t KeyPtr, uint32_t KeyLen, + uint32_t PrevDictEntryIdx, uint32_t Flags, + uint32_t KeyLenPtr, uint32_t ValueLenPtr); +}; + +class AVDictGetKeyValue : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId, + uint32_t KeyPtr, uint32_t KeyLen, uint32_t ValBufPtr, + uint32_t ValBufLen, uint32_t KeyBufPtr, + uint32_t KeyBufLen, uint32_t PrevDictEntryIdx, + uint32_t Flags); +}; + +class AVDictCopy : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestDictId, + uint32_t SrcDictId, uint32_t Flags); +}; + +class AVDictFree : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp b/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp new file mode 100644 index 00000000..d63cd248 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avFrame.h" + +extern "C" { +#include "libavutil/frame.h" +#include "libavutil/pixfmt.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVFrameAlloc::body(const Runtime::CallingFrame &Frame, + uint32_t FramePtr) { + MEMINST_CHECK(MemInst, Frame, 0) + MEM_PTR_CHECK(FrameId, MemInst, uint32_t, FramePtr, + "Failed to access Memory for AVFrame"sv) + + AVFrame *AvFrame = av_frame_alloc(); + FFMPEG_PTR_STORE(AvFrame, FrameId); + return static_cast(ErrNo::Success); +} + +Expect AVFrameFree::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + av_frame_free(&AvFrame); + FFMPEG_PTR_DELETE(FrameId); + return static_cast(ErrNo::Success); +} + +Expect AVFrameWidth::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->width; +} + +Expect AVFrameHeight::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->height; +} + +Expect AVFrameSetHeight::body(const Runtime::CallingFrame &, + uint32_t FrameId, uint32_t Height) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->height = Height; + return static_cast(ErrNo::Success); +} + +Expect AVFrameSetWidth::body(const Runtime::CallingFrame &, + uint32_t FrameId, uint32_t Width) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->width = Width; + return static_cast(ErrNo::Success); +} + +Expect AVFrameVideoFormat::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + int const Format = AvFrame->format; + if (Format == -1) { + return -1; + } + AVPixelFormat const PixelFormat = static_cast(Format); + return FFmpegUtils::PixFmt::fromAVPixFmt(PixelFormat); +} + +Expect AVFrameSetVideoFormat::body(const Runtime::CallingFrame &, + uint32_t FrameId, + uint32_t AvPixFormatId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(AvPixFormatId); + AvFrame->format = PixelFormat; + return static_cast(ErrNo::Success); +} + +Expect AVFrameIsNull::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->data[0] == nullptr; +} + +Expect AVFrameLinesize::body(const Runtime::CallingFrame &, + uint32_t FrameId, uint32_t Idx) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->linesize[Idx]; +} + +Expect AVFrameData::body(const Runtime::CallingFrame &Frame, + uint32_t FrameId, uint32_t FrameBufPtr, + uint32_t FrameBufLen, uint32_t Index) { + MEMINST_CHECK(MemInst, Frame, 0) + MEM_SPAN_CHECK(Buffer, MemInst, uint8_t, FrameBufPtr, FrameBufLen, ""); + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + + uint8_t *Data = AvFrame->data[Index]; + std::copy_n(Data, FrameBufLen, Buffer.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFrameGetBuffer::body(const Runtime::CallingFrame &, + uint32_t FrameId, int32_t Align) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return av_frame_get_buffer(AvFrame, Align); +} + +Expect AVFrameAudioFormat::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + int const Format = AvFrame->format; + if (Format == -1) { + return -1; + } + + AVSampleFormat const SampleFormat = static_cast(Format); + return FFmpegUtils::SampleFmt::toSampleID(SampleFormat); +} + +Expect AVFrameSetAudioFormat::body(const Runtime::CallingFrame &, + uint32_t FrameId, + uint32_t SampleFormatId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVSampleFormat const SampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + AvFrame->format = SampleFormat; + return static_cast(ErrNo::Success); +} + +Expect AVFrameSetChannelLayout::body(const Runtime::CallingFrame &, + uint32_t FrameId, + uint64_t ChannelLayoutID) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + uint64_t const ChannelLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutID); + av_channel_layout_from_mask(&AvFrame->ch_layout, ChannelLayout); + return static_cast(ErrNo::Success); +} + +Expect AVFrameSetNbSamples::body(const Runtime::CallingFrame &, + uint32_t FrameId, int32_t Samples) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->nb_samples = Samples; + return static_cast(ErrNo::Success); +} + +Expect AVFrameNbSamples::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->nb_samples; +} + +Expect AVFrameSampleRate::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->sample_rate; +} + +Expect AVFrameSetSampleRate::body(const Runtime::CallingFrame &, + uint32_t FrameId, + int32_t SampleRate) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->sample_rate = SampleRate; + return static_cast(ErrNo::Success); +} + +Expect AVFrameChannels::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->ch_layout.nb_channels; +} + +Expect AVFrameSetChannels::body(const Runtime::CallingFrame &, + uint32_t FrameId, int32_t Channels) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->ch_layout.nb_channels = Channels; + return static_cast(ErrNo::Success); +} + +Expect AVFrameChannelLayout::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + uint64_t const ChannelLayout = AvFrame->ch_layout.u.mask; + return FFmpegUtils::ChannelLayout::intoChannelLayoutID(ChannelLayout); +} + +Expect AVFrameBestEffortTimestamp::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->best_effort_timestamp; +} + +Expect AVFramePictType::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVPictureType const AvPictureType = AvFrame->pict_type; + return FFmpegUtils::PictureType::fromAVPictureType(AvPictureType); +} + +Expect AVFrameSetPictType::body(const Runtime::CallingFrame &, + uint32_t FrameId, int32_t PictureId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVPictureType const AvPictureType = + FFmpegUtils::PictureType::intoAVPictureType(PictureId); + + AvFrame->pict_type = AvPictureType; + return static_cast(ErrNo::Success); +} + +Expect AVFrameInterlacedFrame::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->interlaced_frame; +} + +Expect AVFrameTopFieldFirst::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->top_field_first; +} + +Expect AVFramePaletteHasChanged::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->palette_has_changed; +} + +Expect AVFrameColorSpace::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVColorSpace const AvColorSpace = AvFrame->colorspace; + return FFmpegUtils::ColorSpace::fromAVColorSpace(AvColorSpace); +} + +Expect AVFrameSetColorSpace::body(const Runtime::CallingFrame &, + uint32_t FrameId, + int32_t ColorSpaceId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->colorspace = FFmpegUtils::ColorSpace::intoAVColorSpace(ColorSpaceId); + return static_cast(ErrNo::Success); +} + +Expect AVFrameColorRange::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVColorRange const AvColorRange = AvFrame->color_range; + + return static_cast(AvColorRange); +} + +Expect AVFrameSetColorRange::body(const Runtime::CallingFrame &, + uint32_t FrameId, + int32_t ColorRangeId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->color_range = static_cast(ColorRangeId); + return static_cast(ErrNo::Success); +} + +Expect +AVFrameColorTransferCharacteristic::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVColorTransferCharacteristic const Characteristic = AvFrame->color_trc; + + // The binding can be used as well. Currently, the binding is commented out. + return static_cast(Characteristic); +} + +Expect AVFrameSetColorTransferCharacteristic::body( + const Runtime::CallingFrame &, uint32_t FrameId, + int32_t ColorTransferCharacteristicId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->color_trc = + static_cast(ColorTransferCharacteristicId); + return static_cast(ErrNo::Success); +} + +Expect AVFrameChromaLocation::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVChromaLocation const AvChromaLocation = AvFrame->chroma_location; + return FFmpegUtils::ChromaLocation::fromAVChromaLocation(AvChromaLocation); +} + +Expect AVFrameRepeatPict::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->repeat_pict; +} + +Expect AVFrameFlags::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->flags; +} + +Expect AVFrameQuality::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->quality; +} + +Expect AVFrameMetadata::body(const Runtime::CallingFrame &Frame, + uint32_t FrameId, uint32_t DictPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, + "Failed when accessing the return AVDictionary memory"sv); + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + + AVDictionary **AvDictionary = + static_cast(av_malloc(sizeof(AVDictionary *))); + + *AvDictionary = AvFrame->metadata; + FFMPEG_PTR_STORE(AvDictionary, DictId); + return static_cast(ErrNo::Success); +} + +Expect AVFrameSetMetadata::body(const Runtime::CallingFrame &, + uint32_t FrameId, uint32_t DictId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); + + if (AvDict == nullptr) { + AvFrame->metadata = nullptr; + } else { + AvFrame->metadata = *AvDict; + } + return static_cast(ErrNo::Success); +} + +Expect AVFrameKeyFrame::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->key_frame; +} + +Expect AVFramePts::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->pts; +} + +Expect AVFrameSetPts::body(const Runtime::CallingFrame &, + uint32_t FrameId, int64_t Pts) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->pts = Pts; + return static_cast(ErrNo::Success); +} + +Expect AVFrameCopy::body(const Runtime::CallingFrame &, + uint32_t DestFrameId, uint32_t SrcFrameId) { + FFMPEG_PTR_FETCH(DestAvFrame, DestFrameId, AVFrame); + FFMPEG_PTR_FETCH(SrcAvFrame, SrcFrameId, AVFrame); + + av_frame_copy(DestAvFrame, SrcAvFrame); + return static_cast(ErrNo::Success); +} + +Expect AVFrameCopyProps::body(const Runtime::CallingFrame &, + uint32_t DestFrameId, + uint32_t SrcFrameId) { + FFMPEG_PTR_FETCH(DestAvFrame, DestFrameId, AVFrame); + FFMPEG_PTR_FETCH(SrcAvFrame, SrcFrameId, AVFrame); + + av_frame_copy_props(DestAvFrame, SrcAvFrame); + return static_cast(ErrNo::Success); +} + +Expect +AVFrameSampleAspectRatio::body(const Runtime::CallingFrame &Frame, + uint32_t FrameId, uint32_t NumPtr, + uint32_t DenPtr) { + MEMINST_CHECK(MemInst, Frame, 0) + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const Rational = AvFrame->sample_aspect_ratio; + *Num = Rational.num; + *Den = Rational.den; + return static_cast(ErrNo::Success); +} + +Expect AVFrameColorPrimaries::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVColorPrimaries const ColorPrimaries = AvFrame->color_primaries; + return FFmpegUtils::ColorPrimaries::fromAVColorPrimaries(ColorPrimaries); +} + +Expect AVFrameSetColorPrimaries::body(const Runtime::CallingFrame &, + uint32_t FrameId, + int32_t ColorPrimariesId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVColorPrimaries const ColorPrimaries = + FFmpegUtils::ColorPrimaries::intoAVColorPrimaries(ColorPrimariesId); + AvFrame->color_primaries = ColorPrimaries; + return static_cast(ErrNo::Success); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avFrame.h b/plugins/wasmedge_ffmpeg/avutil/avFrame.h new file mode 100644 index 00000000..f8c87458 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avFrame.h @@ -0,0 +1,332 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVFrameAlloc : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FramePtr); +}; + +class AVFrameFree : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameWidth : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameHeight : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetWidth : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t Width); +}; + +class AVFrameSetHeight : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t Height); +}; + +class AVFrameVideoFormat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetVideoFormat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t AvPixFormatId); +}; + +class AVFrameIsNull : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameLinesize : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t Idx); +}; + +class AVFrameData : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t FrameBufPtr, uint32_t FrameBufLen, + uint32_t Index); +}; + +class AVFrameGetBuffer : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t Align); +}; + +class AVFrameAudioFormat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetAudioFormat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t SampleFormatId); +}; + +class AVFrameSetChannelLayout : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint64_t ChannelLayoutID); +}; + +class AVFrameSetNbSamples : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t Samples); +}; + +class AVFrameNbSamples : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSampleRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetSampleRate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t SampleRate); +}; + +class AVFrameChannels : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetChannels : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t Channels); +}; + +class AVFrameChannelLayout : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameBestEffortTimestamp + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFramePictType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetPictType : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t PictureId); +}; + +class AVFrameInterlacedFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameTopFieldFirst : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFramePaletteHasChanged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameColorSpace : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetColorSpace : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t ColorSpaceId); +}; + +class AVFrameColorRange : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetColorRange : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t ColorRangeId); +}; + +// color_transfer_characteristic + +class AVFrameColorTransferCharacteristic + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetColorTransferCharacteristic + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t ColorTransferCharacteristicId); +}; + +class AVFrameChromaLocation : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameRepeatPict : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameFlags : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameQuality : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameMetadata : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t DictPtr); +}; + +class AVFrameSetMetadata : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t DictId); +}; + +class AVFrameKeyFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFramePts : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetPts : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int64_t Pts); +}; + +class AVFrameCopy : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestFrameId, + uint32_t SrcFrameId); +}; + +class AVFrameCopyProps : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestFrameId, + uint32_t SrcFrameId); +}; + +class AVFrameSampleAspectRatio : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t NumPtr, uint32_t DenPtr); +}; + +class AVFrameColorPrimaries : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetColorPrimaries : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t ColorPrimariesId); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avRational.cpp b/plugins/wasmedge_ffmpeg/avutil/avRational.cpp new file mode 100644 index 00000000..deedd7e4 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avRational.cpp @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avRational.h" + +extern "C" { +#include "libavutil/rational.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVAddQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(CDen, MemInst, int32_t, CDenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + + AVRational const C = av_add_q(A, B); + *CNum = C.num; + *CDen = C.den; + + return static_cast(ErrNo::Success); +} + +Expect AVSubQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(CDen, MemInst, int32_t, CDenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + + AVRational const C = av_sub_q(A, B); + *CNum = C.num; + *CDen = C.den; + return static_cast(ErrNo::Success); +} + +Expect AVMulQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(CDen, MemInst, int32_t, CDenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + + AVRational const C = av_mul_q(A, B); + *CNum = C.num; + *CDen = C.den; + return static_cast(ErrNo::Success); +} + +Expect AVDivQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(CDen, MemInst, int32_t, CDenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + + AVRational const C = av_div_q(A, B); + *CNum = C.num; + *CDen = C.den; + return static_cast(ErrNo::Success); +} + +Expect AVCmpQ::body(const Runtime::CallingFrame &, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen) { + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + return av_cmp_q(A, B); +} + +Expect AVNearerQ::body(const Runtime::CallingFrame &, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + int32_t CNum, int32_t CDen) { + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + AVRational const C = av_make_q(CNum, CDen); + + return av_nearer_q(A, B, C); +} + +Expect AVQ2d::body(const Runtime::CallingFrame &, int32_t ANum, + int32_t ADen) { + AVRational const A = av_make_q(ANum, ADen); + return av_q2d(A); +} + +Expect AVD2Q::body(const Runtime::CallingFrame &Frame, double_t D, + int32_t Max, uint32_t ANumPtr, uint32_t ADenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(ANum, MemInst, int32_t, ANumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(ADen, MemInst, int32_t, ADenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_d2q(D, Max); + *ANum = A.num; + *ADen = A.den; + return static_cast(ErrNo::Success); +} + +Expect AVQ2IntFloat::body(const Runtime::CallingFrame &, int32_t ANum, + int32_t ADen) { + AVRational const A = av_make_q(ANum, ADen); + return av_q2intfloat(A); +} + +Expect AVInvQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, uint32_t BNumPtr, uint32_t BDenPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(BNum, MemInst, int32_t, BNumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(BDen, MemInst, int32_t, BDenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_inv_q(A); + + *BNum = B.num; + *BDen = B.den; + return static_cast(ErrNo::Success); +} + +Expect AVReduce::body(const Runtime::CallingFrame &Frame, + uint32_t ANumPtr, uint32_t ADenPtr, int64_t BNum, + int64_t BDen, int64_t Max) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(ANum, MemInst, int32_t, ANumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(ADen, MemInst, int32_t, ADenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + return av_reduce(ANum, ADen, BNum, BDen, Max); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avRational.h b/plugins/wasmedge_ffmpeg/avutil/avRational.h new file mode 100644 index 00000000..82866940 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avRational.h @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVAddQ : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr); +}; + +class AVSubQ : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr); +}; + +class AVMulQ : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr); +}; + +class AVDivQ : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr); +}; + +class AVCmpQ : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen); +}; + +class AVNearerQ : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, int32_t CNum, + int32_t CDen); +}; + +class AVQ2d : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen); +}; + +class AVD2Q : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, double_t D, + int32_t Max, uint32_t ANumPtr, uint32_t ADenPtr); +}; + +class AVQ2IntFloat : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen); +}; + +class AVInvQ : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, uint32_t BNumPtr, uint32_t BDenPtr); +}; + +class AVReduce : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t ANumPtr, + uint32_t ADenPtr, int64_t BNum, int64_t BDen, + int64_t Max); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avTime.cpp b/plugins/wasmedge_ffmpeg/avutil/avTime.cpp new file mode 100644 index 00000000..40cdaba1 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avTime.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avTime.h" + +extern "C" { +#include "libavutil/time.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVGetTime::body(const Runtime::CallingFrame &) { + return av_gettime(); +} + +Expect AVGetTimeRelative::body(const Runtime::CallingFrame &) { + return av_gettime_relative(); +} + +Expect +AVGetTimeRelativeIsMonotonic::body(const Runtime::CallingFrame &) { + return av_gettime_relative_is_monotonic(); +} + +Expect AVUSleep::body(const Runtime::CallingFrame &, uint32_t USec) { + return av_usleep(USec); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avTime.h b/plugins/wasmedge_ffmpeg/avutil/avTime.h new file mode 100644 index 00000000..6ec6e2c6 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avTime.h @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVGetTime : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVGetTimeRelative : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVGetTimeRelativeIsMonotonic + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVUSleep : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t USec); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp b/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp new file mode 100644 index 00000000..da812afc --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avutil_func.h" + +extern "C" { +#include "libavutil/avutil.h" +#include "libavutil/time.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVLogSetLevel::body(const Runtime::CallingFrame &, + int32_t LogLevelId) { + av_log_set_level(LogLevelId); + return {}; +} + +Expect AVLogGetLevel::body(const Runtime::CallingFrame &) { + return av_log_get_level(); +} + +Expect AVLogGetFlags::body(const Runtime::CallingFrame &) { + return av_log_get_flags(); +} + +Expect AVLogSetFlags::body(const Runtime::CallingFrame &, + int32_t FlagId) { + av_log_set_flags(FlagId); + return {}; +} + +Expect AVRescaleQ::body(const Runtime::CallingFrame &, int64_t A, + int32_t BNum, int32_t BDen, int32_t CNum, + int32_t CDen) { + AVRational const B = av_make_q(BNum, BDen); + AVRational const C = av_make_q(CNum, CDen); + return av_rescale_q(A, B, C); +} + +Expect AVRescaleQRnd::body(const Runtime::CallingFrame &, int64_t A, + int32_t BNum, int32_t BDen, int32_t CNum, + int32_t CDen, int32_t RoundingId) { + AVRational const B = av_make_q(BNum, BDen); + AVRational const C = av_make_q(CNum, CDen); + AVRounding const Rounding = FFmpegUtils::Rounding::intoAVRounding(RoundingId); + return av_rescale_q_rnd(A, B, C, Rounding); +} + +Expect AVUtilVersion::body(const Runtime::CallingFrame &) { + return avutil_version(); +} + +Expect +AVGetChannelLayoutNbChannels::body(const Runtime::CallingFrame &, + uint64_t ChannelLayoutId) { + uint64_t const ChannelLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + + AVChannelLayout TmpChLayout; + av_channel_layout_from_mask(&TmpChLayout, ChannelLayout); + int32_t ChannelLayoutNbChannels = TmpChLayout.nb_channels; + av_channel_layout_uninit(&TmpChLayout); + + return ChannelLayoutNbChannels; +} + +Expect AVGetChannelLayoutNameLen::body(const Runtime::CallingFrame &, + uint64_t ChannelLayoutId) { + uint64_t const ChannelLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + char ChName[16] = {0}; // bufsize based on AVChannelCustom.name + // mask ChannelLayout to AVChannel before passing + int Code = + av_channel_name(ChName, 16, static_cast(ChannelLayout >> 1)); + + if (Code < 0) { + return 0; + } + return strlen(ChName); +} + +Expect AVGetChannelLayoutName::body(const Runtime::CallingFrame &Frame, + uint64_t ChannelLayoutId, + uint32_t NamePtr, + uint32_t NameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + + uint64_t const ChannelLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + char ChName[16] = {0}; // bufsize based on AVChannelCustom.name + // mask ChannelLayout to AVChannel before passing + av_channel_name(ChName, 16, static_cast(ChannelLayout >> 1)); + + std::copy_n(ChName, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVGetChannelLayoutMask::body(const Runtime::CallingFrame &, + uint64_t ChannelLayoutId) { + uint64_t const ChannelLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + return ChannelLayout; +} + +Expect AVGetDefaultChannelLayout::body(const Runtime::CallingFrame &, + int32_t Number) { + AVChannelLayout TmpChLayout; + av_channel_layout_default(&TmpChLayout, Number); + uint64_t DefaultChannelLayout = + FFmpegUtils::ChannelLayout::intoChannelLayoutID(TmpChLayout.u.mask); + av_channel_layout_uninit(&TmpChLayout); + + return DefaultChannelLayout; +} + +Expect AVUtilConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = avutil_configuration(); + return strlen(Config); +} + +Expect AVUtilConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = avutil_configuration(); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVUtilLicenseLength::body(const Runtime::CallingFrame &) { + const char *License = avutil_license(); + return strlen(License); +} + +Expect AVUtilLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, uint32_t LicenseLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = avutil_license(); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_func.h b/plugins/wasmedge_ffmpeg/avutil/avutil_func.h new file mode 100644 index 00000000..ebdee0e4 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_func.h @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVLogSetLevel : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t LogLevelId); +}; + +class AVLogGetLevel : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVLogSetFlags : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t FlagsId); +}; + +class AVLogGetFlags : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +// Option functions. +class AVOptSetBin : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSet : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetInt : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetDouble : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetQ : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetImageSize : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetPixelFmt : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetSampleFmt : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetChannelLayout : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVRescaleQ : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int64_t A, + int32_t BNum, int32_t BDen, int32_t CNum, int32_t CDen); +}; + +class AVRescaleQRnd : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, int64_t A, int32_t BNum, + int32_t BDen, int32_t CNum, int32_t CDen, + int32_t RoundingId); +}; + +class AVUtilVersion : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &); +}; + +class AVGetChannelLayoutNbChannels + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint64_t ChannelLayoutId); +}; + +class AVGetChannelLayoutNameLen + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint64_t ChannelLayoutId); +}; + +class AVGetChannelLayoutName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint64_t ChannelLayoutId, uint32_t NamePtr, + uint32_t NameLen); +}; + +class AVGetChannelLayoutMask : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint64_t ChannelLayoutId); +}; + +class AVGetDefaultChannelLayout + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t ChannelLayoutId); +}; + +class AVUtilConfigurationLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVUtilConfiguration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class AVUtilLicenseLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVUtilLicense : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/error.cpp b/plugins/wasmedge_ffmpeg/avutil/error.cpp new file mode 100644 index 00000000..90df0fbf --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/error.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "error.h" + +extern "C" { +#include "libavutil/error.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVUtilAVStrError::body(const Runtime::CallingFrame &Frame, + int32_t ErrNum, uint32_t ErrBuf, + uint32_t BufLen) { + MEMINST_CHECK(MemInst, Frame, 0); + + MEM_PTR_CHECK(ErrId, MemInst, char, ErrBuf, + "Failed when accessing the return URL memory"sv); + + std::string Error; + std::copy_n(ErrId, BufLen, std::back_inserter(Error)); + return av_strerror(ErrNum, const_cast(Error.c_str()), BufLen); +} + +Expect AVUtilAVError::body(const Runtime::CallingFrame &, + int32_t ErrNum) { + return AVERROR(ErrNum); +} + +Expect AVUtilAVUNError::body(const Runtime::CallingFrame &, + int32_t ErrNum) { + return AVUNERROR(ErrNum); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/error.h b/plugins/wasmedge_ffmpeg/avutil/error.h new file mode 100644 index 00000000..0ff38ce5 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/error.h @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVUtilAVStrError : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ErrNum, + uint32_t ErrBuf, uint32_t BufLen); +}; + +class AVUtilAVError : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ErrNum); +}; + +class AVUtilAVUNError : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ErrNum); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/module.cpp b/plugins/wasmedge_ffmpeg/avutil/module.cpp new file mode 100644 index 00000000..948ff3ec --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/module.cpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "avDictionary.h" +#include "avFrame.h" +#include "avRational.h" +#include "avTime.h" +#include "avutil_func.h" +#include "error.h" +#include "pixfmt.h" +#include "samplefmt.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +WasmEdgeFFmpegAVUtilModule::WasmEdgeFFmpegAVUtilModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_avutil") { + // error.h + addHostFunc("wasmedge_ffmpeg_avutil_av_strerror", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_AVERROR", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_AVUNERROR", + std::make_unique(Env)); + + // rational.h + addHostFunc("wasmedge_ffmpeg_avutil_av_add_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_sub_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_mul_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_div_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_d2q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_q2d", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_inv_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_q2intfloat", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_nearer_q", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_cmp_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_reduce", + std::make_unique(Env)); + + // frame.h + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_alloc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_free", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_width", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_height", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_width", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_height", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_video_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_video_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_isnull", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_linesize", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_data", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_get_buffer", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_audio_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_audio_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_nb_samples", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_nb_samples", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_sample_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_sample_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_channels", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_channels", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_best_effort_timestamp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_pict_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_pict_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_interlaced_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_top_field_first", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_palette_has_changed", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_colorspace", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_colorspace", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_color_range", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_color_range", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_color_trc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_color_trc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_chroma_location", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_repeat_pict", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_quality", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_key_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_pts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_pts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_copy", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_copy_props", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_sample_aspect_ratio", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_color_primaries", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_color_primaries", + std::make_unique(Env)); + + // pixfmt.h (Even AvPixFmtDesc is in this file) + addHostFunc("wasmedge_ffmpeg_avutil_avpixfmtdescriptor_nb_components", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avpixfmtdescriptor_log2_chromaw", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avpixfmtdescriptor_log2_chromah", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_transfer_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_transfer_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_range_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_range_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_space_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_space_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_primaries_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_primaries_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_pix_format_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_pix_format_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_pix_format_mask", + std::make_unique(Env)); + + // samplefmt.h + addHostFunc("wasmedge_ffmpeg_avutil_av_get_packed_sample_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_planar_sample_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_sample_fmt_is_planar", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_bytes_per_sample", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_sample_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_samples_get_buffer_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_samples_alloc_array_and_samples", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_sample_fmt_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_sample_fmt_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_sample_fmt_mask", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_freep", + std::make_unique(Env)); + + // dict.h + addHostFunc("wasmedge_ffmpeg_avutil_av_dict_set", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_dict_get", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_dict_get_key_value", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_dict_copy", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_dict_free", + std::make_unique(Env)); + + // avutil_func.h + addHostFunc("wasmedge_ffmpeg_avutil_av_log_set_level", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_log_get_level", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_log_set_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_log_get_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_rescale_q", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_rescale_q_rnd", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_channel_layout_nb_channels", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avutil_av_get_channel_layout_name_len", // TODO: Write + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avutil_av_get_channel_layout_name", // TODO: Write Test + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avutil_av_get_channel_layout_mask", // TODO: Write Test + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_default_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avutil_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avutil_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avutil_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avutil_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avutil_license", + std::make_unique(Env)); + + // time.h + addHostFunc("wasmedge_ffmpeg_avutil_av_gettime", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_gettime_relative", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_gettime_relative_is_monotonic", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_usleep", + std::make_unique(Env)); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/module.h b/plugins/wasmedge_ffmpeg/avutil/module.h new file mode 100644 index 00000000..6c22537b --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/module.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class WasmEdgeFFmpegAVUtilModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegAVUtilModule(std::shared_ptr Env); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp b/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp new file mode 100644 index 00000000..8886459f --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "pixfmt.h" + +extern "C" { +#include "libavutil/pixdesc.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect +AvPixFmtDescriptorNbComponents::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + const AVPixFmtDescriptor *AvPixFmtDescriptor = + av_pix_fmt_desc_get(PixelFormat); + return AvPixFmtDescriptor->nb_components; +} + +Expect +AvPixFmtDescriptorLog2ChromaW::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + const AVPixFmtDescriptor *AvPixFmtDescriptor = + av_pix_fmt_desc_get(PixelFormat); + return AvPixFmtDescriptor->log2_chroma_w; +} + +Expect +AvPixFmtDescriptorLog2ChromaH::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + const AVPixFmtDescriptor *AvPixFmtDescriptor = + av_pix_fmt_desc_get(PixelFormat); + return AvPixFmtDescriptor->log2_chroma_h; +} + +Expect AVColorRangeNameLength::body(const Runtime::CallingFrame &, + int32_t RangeId) { + AVColorRange const ColorRange = static_cast(RangeId); + const char *Name = av_color_range_name(ColorRange); + return strlen(Name); +} + +Expect AVColorRangeName::body(const Runtime::CallingFrame &Frame, + int32_t RangeId, uint32_t RangeNamePtr, + uint32_t RangeLength) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(RangeNameBuf, MemInst, char, RangeNamePtr, RangeLength, ""); + + AVColorRange const ColorRange = static_cast(RangeId); + const char *RangeName = av_color_range_name(ColorRange); + auto Actual = std::strlen(RangeName); + auto N = std::min(RangeLength, static_cast(Actual + 1)); + std::copy_n(RangeName, N, RangeNameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVColorTransferNameLength::body(const Runtime::CallingFrame &, + int32_t TransferId) { + AVColorTransferCharacteristic const Characteristic = + static_cast(TransferId); + const char *Name = av_color_transfer_name(Characteristic); + return strlen(Name); +} + +Expect AVColorTransferName::body(const Runtime::CallingFrame &Frame, + int32_t TransferId, + uint32_t TransferNamePtr, + uint32_t TransferLength) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(TransferNameBuf, MemInst, char, TransferNamePtr, + TransferLength, ""); + + AVColorTransferCharacteristic const Characteristic = + static_cast(TransferId); + const char *TransferName = av_color_transfer_name(Characteristic); + auto Actual = std::strlen(TransferName); + auto N = + std::min(TransferLength, static_cast(Actual + 1)); + std::copy_n(TransferName, N, TransferNameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVColorSpaceNameLength::body(const Runtime::CallingFrame &, + int32_t ColorSpaceId) { + AVColorSpace const ColorSpace = static_cast(ColorSpaceId); + const char *Name = av_color_space_name(ColorSpace); + return strlen(Name); +} + +Expect AVColorSpaceName::body(const Runtime::CallingFrame &Frame, + int32_t ColorSpaceId, + uint32_t ColorSpaceNamePtr, + uint32_t ColorSpaceLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ColorSpaceBuf, MemInst, char, ColorSpaceNamePtr, ColorSpaceLen, + ""); + + AVColorSpace const ColorSpace = static_cast(ColorSpaceId); + const char *ColorSpaceName = av_color_space_name(ColorSpace); + auto Actual = std::strlen(ColorSpaceName); + auto N = std::min(ColorSpaceLen, static_cast(Actual + 1)); + std::copy_n(ColorSpaceName, N, ColorSpaceBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVColorPrimariesNameLength::body(const Runtime::CallingFrame &, + int32_t ColorPrimariesId) { + AVColorPrimaries const ColorPrimaries = + FFmpegUtils::ColorPrimaries::intoAVColorPrimaries(ColorPrimariesId); + const char *Name = av_color_primaries_name(ColorPrimaries); + return strlen(Name); +} + +Expect AVColorPrimariesName::body(const Runtime::CallingFrame &Frame, + int32_t ColorPrimariesId, + uint32_t ColorPrimariesNamePtr, + uint32_t ColorPrimariesLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ColorPrimariesBuf, MemInst, char, ColorPrimariesNamePtr, + ColorPrimariesLen, ""); + + AVColorPrimaries const ColorPrimaries = + FFmpegUtils::ColorPrimaries::intoAVColorPrimaries(ColorPrimariesId); + const char *PrimariesName = av_color_primaries_name(ColorPrimaries); + auto Actual = std::strlen(PrimariesName); + auto N = + std::min(ColorPrimariesLen, static_cast(Actual + 1)); + std::copy_n(PrimariesName, N, ColorPrimariesBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVPixelFormatNameLength::body(const Runtime::CallingFrame &, + uint32_t AvPixFormatId) { + AVPixelFormat const PixFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(AvPixFormatId); + const AVPixFmtDescriptor *PixFmtDescriptor = av_pix_fmt_desc_get(PixFormat); + + return strlen(PixFmtDescriptor->name); +} + +Expect AVPixelFormatName::body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId, + uint32_t PixFormatNamePtr, + uint32_t PixFormatNameLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(PixFormatBuf, MemInst, char, PixFormatNamePtr, + PixFormatNameLen, ""); + + AVPixelFormat const PixFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + const AVPixFmtDescriptor *PixFmtDescriptor = av_pix_fmt_desc_get(PixFormat); + const char *PixFormatName = PixFmtDescriptor->name; + auto Actual = std::strlen(PixFormatName); + auto N = + std::min(PixFormatNameLen, static_cast(Actual + 1)); + std::copy_n(PixFormatName, N, PixFormatBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVPixelFormatMask::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + return static_cast(PixelFormat); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/pixfmt.h b/plugins/wasmedge_ffmpeg/avutil/pixfmt.h new file mode 100644 index 00000000..0b5dfc64 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/pixfmt.h @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AvPixFmtDescriptorNbComponents + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class AvPixFmtDescriptorLog2ChromaW + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class AvPixFmtDescriptorLog2ChromaH + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class AVColorRangeNameLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t RangeId); +}; + +class AVColorRangeName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t RangeId, + uint32_t RangeName, uint32_t RangeLength); +}; + +class AVColorTransferNameLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t TransferId); +}; + +class AVColorTransferName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t TransferId, + uint32_t TransferNamePtr, uint32_t TransferLength); +}; + +class AVColorSpaceNameLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t ColorSpaceId); +}; + +class AVColorSpaceName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t ColorSpaceId, + uint32_t ColorSpaceNamePtr, uint32_t ColorSpaceLen); +}; + +class AVColorPrimariesNameLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t ColorPrimariesId); +}; + +class AVColorPrimariesName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + int32_t ColorPrimariesId, uint32_t ColorPrimariesNamePtr, + uint32_t ColorPrimariesLen); +}; + +class AVPixelFormatNameLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvPixFormatId); +}; + +class AVPixelFormatName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId, + uint32_t PixFormatNamePtr, uint32_t PixFormatNameLen); +}; + +class AVPixelFormatMask : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp b/plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp new file mode 100644 index 00000000..40150a68 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "samplefmt.h" + +extern "C" { +#include "libavutil/samplefmt.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVGetPlanarSampleFmt::body(const Runtime::CallingFrame &, + uint32_t SampleFormatId) { + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + AVSampleFormat const PlanarSampleFmt = + av_get_planar_sample_fmt(AvSampleFormat); + return FFmpegUtils::SampleFmt::toSampleID(PlanarSampleFmt); +} + +Expect AVGetPackedSampleFmt::body(const Runtime::CallingFrame &, + uint32_t SampleFormatId) { + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + AVSampleFormat const PackedSampleFmt = + av_get_packed_sample_fmt(AvSampleFormat); + return FFmpegUtils::SampleFmt::toSampleID(PackedSampleFmt); +} + +Expect AVSampleFmtIsPlanar::body(const Runtime::CallingFrame &, + uint32_t SampleFormatId) { + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + return av_sample_fmt_is_planar(AvSampleFormat); +} + +Expect AVGetBytesPerSample::body(const Runtime::CallingFrame &, + uint32_t SampleFormatId) { + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + return av_get_bytes_per_sample(AvSampleFormat); +} + +Expect AVGetSampleFmt::body(const Runtime::CallingFrame &Frame, + uint32_t Str, uint32_t StrLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(StrId, MemInst, char, Str, ""); + + std::string TargetUrl; + std::copy_n(StrId, StrLen, std::back_inserter(TargetUrl)); + + AVSampleFormat const AvSampleFormat = av_get_sample_fmt(TargetUrl.c_str()); + return FFmpegUtils::SampleFmt::toSampleID(AvSampleFormat); +} + +Expect AVSamplesGetBufferSize::body(const Runtime::CallingFrame &, + int32_t NbChannels, + int32_t NbSamples, + uint32_t SampleFormatId, + int32_t Align) { + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + return av_samples_get_buffer_size(nullptr, NbChannels, NbSamples, + AvSampleFormat, + Align); // linesize is NULL in RustSDK. +} + +Expect +AVSamplesAllocArrayAndSamples::body(const Runtime::CallingFrame &Frame, + uint32_t BufferPtr, uint32_t LinesizePtr, + int32_t NbChannels, int32_t NbSamples, + uint32_t SampleFmtId, int32_t Align) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(BufId, MemInst, uint32_t, BufferPtr, ""); + MEM_PTR_CHECK(LineSize, MemInst, int32_t, LinesizePtr, ""); + + FFMPEG_PTR_FETCH(Buf, *BufId, uint8_t *); + int LineSizeValue = 0; + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + int Res = av_samples_alloc_array_and_samples( + &Buf, &LineSizeValue, NbChannels, NbSamples, AvSampleFormat, Align); + + *LineSize = LineSizeValue; + FFMPEG_PTR_STORE(Buf, BufId); + return Res; +} + +Expect AVGetSampleFmtNameLength::body(const Runtime::CallingFrame &, + uint32_t SampleFmtId) { + AVSampleFormat const SampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + + const char *Name = av_get_sample_fmt_name(SampleFmt); + return strlen(Name); +} + +Expect AVGetSampleFmtName::body(const Runtime::CallingFrame &Frame, + uint32_t SampleFmtId, + uint32_t SampleFmtNamePtr, + uint32_t SampleFmtNameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(SampleFmtBuf, MemInst, char, SampleFmtNamePtr, + SampleFmtNameLen, ""); + AVSampleFormat const SampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + const char *Name = av_get_sample_fmt_name(SampleFmt); + std::copy_n(Name, SampleFmtNameLen, SampleFmtBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVGetSampleFmtMask::body(const Runtime::CallingFrame &, + uint32_t SampleFmtId) { + AVSampleFormat const SampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + return static_cast(SampleFmt); +} + +Expect AVFreep::body(const Runtime::CallingFrame &, + uint32_t BufferId) { + FFMPEG_PTR_FETCH(Buffer, BufferId, uint8_t *); + av_freep(Buffer); + FFMPEG_PTR_DELETE(BufferId); + return static_cast(ErrNo::Success); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/samplefmt.h b/plugins/wasmedge_ffmpeg/avutil/samplefmt.h new file mode 100644 index 00000000..373ec2b7 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/samplefmt.h @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVGetPlanarSampleFmt : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFormatId); +}; + +class AVGetPackedSampleFmt : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFormatId); +}; + +class AVSampleFmtIsPlanar : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFormatId); +}; + +class AVGetBytesPerSample : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFormatId); +}; + +class AVGetSampleFmt : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t Str, + uint32_t StrLen); +}; + +class AVSamplesGetBufferSize : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, int32_t NbChannels, + int32_t NbSamples, uint32_t SampleFormatId, + int32_t Align); +}; + +class AVSamplesAllocArrayAndSamples + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufferPtr, + uint32_t LinesizePtr, int32_t NbChannels, + int32_t NbSamples, uint32_t SampleFmtId, int32_t Align); +}; + +class AVGetSampleFmtNameLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFmtId); +}; + +class AVGetSampleFmtName : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SampleFmtId, + uint32_t SampleFmtNamePtr, uint32_t SampleFmtNameLen); +}; + +class AVGetSampleFmtMask : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFmtId); +}; + +class AVFreep : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufferId); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/bindings.h b/plugins/wasmedge_ffmpeg/bindings.h new file mode 100644 index 00000000..22b21370 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/bindings.h @@ -0,0 +1,4517 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +extern "C" { +#include "libavcodec/avcodec.h" +#include "libavutil/avutil.h" +#include "libavutil/opt.h" +#include "libswresample/swresample.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace FFmpegUtils { +class MediaType { +public: + static AVMediaType intoMediaType(int32_t MediaTypeId) { + switch (MediaTypeId) { + case 0: + return AVMEDIA_TYPE_VIDEO; + case 1: + return AVMEDIA_TYPE_AUDIO; + case 2: + return AVMEDIA_TYPE_DATA; + case 3: + return AVMEDIA_TYPE_SUBTITLE; + case 4: + return AVMEDIA_TYPE_ATTACHMENT; + case 5: + return AVMEDIA_TYPE_NB; + default: + return AVMEDIA_TYPE_UNKNOWN; + } + } + + static int32_t fromMediaType(AVMediaType MediaType) { + switch (MediaType) { + case AVMEDIA_TYPE_VIDEO: + return 0; + case AVMEDIA_TYPE_AUDIO: + return 1; + case AVMEDIA_TYPE_DATA: + return 2; + case AVMEDIA_TYPE_SUBTITLE: + return 3; + case 4: + return AVMEDIA_TYPE_ATTACHMENT; + case 5: + return AVMEDIA_TYPE_NB; + default: + return AVMEDIA_TYPE_UNKNOWN; + } + } +}; + +class CodecID { +public: + static AVCodecID intoAVCodecID(uint32_t AvCodecIndex) { + switch (AvCodecIndex) { + case 0: + return AV_CODEC_ID_NONE; + case 1: + return AV_CODEC_ID_MPEG1VIDEO; + case 2: + return AV_CODEC_ID_MPEG2VIDEO; + case 3: + return AV_CODEC_ID_H261; + case 4: + return AV_CODEC_ID_H263; + case 5: + return AV_CODEC_ID_RV10; + case 6: + return AV_CODEC_ID_RV20; + case 7: + return AV_CODEC_ID_MJPEG; + case 8: + return AV_CODEC_ID_MJPEGB; + case 9: + return AV_CODEC_ID_LJPEG; + case 10: + return AV_CODEC_ID_SP5X; + case 11: + return AV_CODEC_ID_JPEGLS; + case 12: + return AV_CODEC_ID_MPEG4; + case 13: + return AV_CODEC_ID_RAWVIDEO; + case 14: + return AV_CODEC_ID_MSMPEG4V1; + case 15: + return AV_CODEC_ID_MSMPEG4V2; + case 16: + return AV_CODEC_ID_MSMPEG4V3; + case 17: + return AV_CODEC_ID_WMV1; + case 18: + return AV_CODEC_ID_WMV2; + case 19: + return AV_CODEC_ID_H263P; + case 20: + return AV_CODEC_ID_H263I; + case 21: + return AV_CODEC_ID_FLV1; + case 22: + return AV_CODEC_ID_SVQ1; + case 23: + return AV_CODEC_ID_SVQ3; + case 24: + return AV_CODEC_ID_DVVIDEO; + case 25: + return AV_CODEC_ID_HUFFYUV; + case 26: + return AV_CODEC_ID_CYUV; + case 27: + return AV_CODEC_ID_H264; + case 28: + return AV_CODEC_ID_INDEO3; + case 29: + return AV_CODEC_ID_VP3; + case 30: + return AV_CODEC_ID_THEORA; + case 31: + return AV_CODEC_ID_ASV1; + case 32: + return AV_CODEC_ID_ASV2; + case 33: + return AV_CODEC_ID_FFV1; + case 34: + return AV_CODEC_ID_4XM; + case 35: + return AV_CODEC_ID_VCR1; + case 36: + return AV_CODEC_ID_CLJR; + case 37: + return AV_CODEC_ID_MDEC; + case 38: + return AV_CODEC_ID_ROQ; + case 39: + return AV_CODEC_ID_INTERPLAY_VIDEO; + case 40: + return AV_CODEC_ID_XAN_WC3; + case 41: + return AV_CODEC_ID_XAN_WC4; + case 42: + return AV_CODEC_ID_RPZA; + case 43: + return AV_CODEC_ID_CINEPAK; + case 44: + return AV_CODEC_ID_WS_VQA; + case 45: + return AV_CODEC_ID_MSRLE; + case 46: + return AV_CODEC_ID_MSVIDEO1; + case 47: + return AV_CODEC_ID_IDCIN; + case 48: + return AV_CODEC_ID_8BPS; + case 49: + return AV_CODEC_ID_SMC; + case 50: + return AV_CODEC_ID_FLIC; + case 51: + return AV_CODEC_ID_TRUEMOTION1; + case 52: + return AV_CODEC_ID_VMDVIDEO; + case 53: + return AV_CODEC_ID_MSZH; + case 54: + return AV_CODEC_ID_ZLIB; + case 55: + return AV_CODEC_ID_QTRLE; + case 56: + return AV_CODEC_ID_TSCC; + case 57: + return AV_CODEC_ID_ULTI; + case 58: + return AV_CODEC_ID_QDRAW; + case 59: + return AV_CODEC_ID_VIXL; + case 60: + return AV_CODEC_ID_QPEG; + case 61: + return AV_CODEC_ID_PNG; + case 62: + return AV_CODEC_ID_PPM; + case 63: + return AV_CODEC_ID_PBM; + case 64: + return AV_CODEC_ID_PGM; + case 65: + return AV_CODEC_ID_PGMYUV; + case 66: + return AV_CODEC_ID_PAM; + case 67: + return AV_CODEC_ID_FFVHUFF; + case 68: + return AV_CODEC_ID_RV30; + case 69: + return AV_CODEC_ID_RV40; + case 70: + return AV_CODEC_ID_VC1; + case 71: + return AV_CODEC_ID_WMV3; + case 72: + return AV_CODEC_ID_LOCO; + case 73: + return AV_CODEC_ID_WNV1; + case 74: + return AV_CODEC_ID_AASC; + case 75: + return AV_CODEC_ID_INDEO2; + case 76: + return AV_CODEC_ID_FRAPS; + case 77: + return AV_CODEC_ID_TRUEMOTION2; + case 78: + return AV_CODEC_ID_BMP; + case 79: + return AV_CODEC_ID_CSCD; + case 80: + return AV_CODEC_ID_MMVIDEO; + case 81: + return AV_CODEC_ID_ZMBV; + case 82: + return AV_CODEC_ID_AVS; + case 83: + return AV_CODEC_ID_SMACKVIDEO; + case 84: + return AV_CODEC_ID_NUV; + case 85: + return AV_CODEC_ID_KMVC; + case 86: + return AV_CODEC_ID_FLASHSV; + case 87: + return AV_CODEC_ID_CAVS; + case 88: + return AV_CODEC_ID_JPEG2000; + case 89: + return AV_CODEC_ID_VMNC; + case 90: + return AV_CODEC_ID_VP5; + case 91: + return AV_CODEC_ID_VP6; + case 92: + return AV_CODEC_ID_VP6F; + case 93: + return AV_CODEC_ID_TARGA; + case 94: + return AV_CODEC_ID_DSICINVIDEO; + case 95: + return AV_CODEC_ID_TIERTEXSEQVIDEO; + case 96: + return AV_CODEC_ID_TIFF; + case 97: + return AV_CODEC_ID_GIF; + case 98: + return AV_CODEC_ID_DXA; + case 99: + return AV_CODEC_ID_DNXHD; + case 100: + return AV_CODEC_ID_THP; + case 101: + return AV_CODEC_ID_SGI; + case 102: + return AV_CODEC_ID_C93; + case 103: + return AV_CODEC_ID_BETHSOFTVID; + case 104: + return AV_CODEC_ID_PTX; + case 105: + return AV_CODEC_ID_TXD; + case 106: + return AV_CODEC_ID_VP6A; + case 107: + return AV_CODEC_ID_AMV; + case 108: + return AV_CODEC_ID_VB; + case 109: + return AV_CODEC_ID_PCX; + case 110: + return AV_CODEC_ID_SUNRAST; + case 111: + return AV_CODEC_ID_INDEO4; + case 112: + return AV_CODEC_ID_INDEO5; + case 113: + return AV_CODEC_ID_MIMIC; + case 114: + return AV_CODEC_ID_RL2; + case 115: + return AV_CODEC_ID_ESCAPE124; + case 116: + return AV_CODEC_ID_DIRAC; + case 117: + return AV_CODEC_ID_BFI; + case 118: + return AV_CODEC_ID_CMV; + case 119: + return AV_CODEC_ID_MOTIONPIXELS; + case 120: + return AV_CODEC_ID_TGV; + case 121: + return AV_CODEC_ID_TGQ; + case 122: + return AV_CODEC_ID_TQI; + case 123: + return AV_CODEC_ID_AURA; + case 124: + return AV_CODEC_ID_AURA2; + case 125: + return AV_CODEC_ID_V210X; + case 126: + return AV_CODEC_ID_TMV; + case 127: + return AV_CODEC_ID_V210; + case 128: + return AV_CODEC_ID_DPX; + case 129: + return AV_CODEC_ID_MAD; + case 130: + return AV_CODEC_ID_FRWU; + case 131: + return AV_CODEC_ID_FLASHSV2; + case 132: + return AV_CODEC_ID_CDGRAPHICS; + case 133: + return AV_CODEC_ID_R210; + case 134: + return AV_CODEC_ID_ANM; + case 135: + return AV_CODEC_ID_BINKVIDEO; + case 136: + return AV_CODEC_ID_IFF_ILBM; + case 137: + return AV_CODEC_ID_IFF_ILBM; + case 138: + return AV_CODEC_ID_KGV1; + case 139: + return AV_CODEC_ID_YOP; + case 140: + return AV_CODEC_ID_VP8; + case 141: + return AV_CODEC_ID_PICTOR; + case 142: + return AV_CODEC_ID_ANSI; + case 143: + return AV_CODEC_ID_A64_MULTI; + case 144: + return AV_CODEC_ID_A64_MULTI5; + case 145: + return AV_CODEC_ID_R10K; + case 146: + return AV_CODEC_ID_MXPEG; + case 147: + return AV_CODEC_ID_LAGARITH; + case 148: + return AV_CODEC_ID_PRORES; + case 149: + return AV_CODEC_ID_JV; + case 150: + return AV_CODEC_ID_DFA; + case 151: + return AV_CODEC_ID_WMV3IMAGE; + case 152: + return AV_CODEC_ID_VC1IMAGE; + case 153: + return AV_CODEC_ID_UTVIDEO; + case 154: + return AV_CODEC_ID_BMV_VIDEO; + case 155: + return AV_CODEC_ID_VBLE; + case 156: + return AV_CODEC_ID_DXTORY; + case 157: + return AV_CODEC_ID_V410; + case 158: + return AV_CODEC_ID_XWD; + case 159: + return AV_CODEC_ID_CDXL; + case 160: + return AV_CODEC_ID_XBM; + case 161: + return AV_CODEC_ID_ZEROCODEC; + case 162: + return AV_CODEC_ID_MSS1; + case 163: + return AV_CODEC_ID_MSA1; + case 164: + return AV_CODEC_ID_TSCC2; + case 165: + return AV_CODEC_ID_MTS2; + case 166: + return AV_CODEC_ID_CLLC; + case 167: + return AV_CODEC_ID_MSS2; + case 168: + return AV_CODEC_ID_VP9; + case 169: + return AV_CODEC_ID_AIC; + case 170: + return AV_CODEC_ID_ESCAPE130; + case 171: + return AV_CODEC_ID_G2M; + case 172: + return AV_CODEC_ID_WEBP; + case 173: + return AV_CODEC_ID_HNM4_VIDEO; + case 174: + return AV_CODEC_ID_HEVC; + case 175: + return AV_CODEC_ID_HEVC; + case 176: + return AV_CODEC_ID_FIC; + case 177: + return AV_CODEC_ID_ALIAS_PIX; + case 178: + return AV_CODEC_ID_BRENDER_PIX; + case 179: + return AV_CODEC_ID_PAF_VIDEO; + case 180: + return AV_CODEC_ID_EXR; + case 181: + return AV_CODEC_ID_VP7; + case 182: + return AV_CODEC_ID_SANM; + case 183: + return AV_CODEC_ID_SGIRLE; + case 184: + return AV_CODEC_ID_MVC1; + case 185: + return AV_CODEC_ID_MVC2; + case 186: + return AV_CODEC_ID_HQX; + case 187: + return AV_CODEC_ID_TDSC; + case 188: + return AV_CODEC_ID_HQ_HQA; + case 189: + return AV_CODEC_ID_HAP; + case 190: + return AV_CODEC_ID_DDS; + case 191: + return AV_CODEC_ID_DXV; + case 192: + return AV_CODEC_ID_SCREENPRESSO; + case 193: + return AV_CODEC_ID_RSCC; + /////////////////////////////// + // case 194: + // return AV_CODEC_ID_Y41P; + // case 194: + // return AV_CODEC_ID_AVS2; + case 194: + return AV_CODEC_ID_Y41P; + case 195: + return AV_CODEC_ID_AVRP; + case 196: + return AV_CODEC_ID_012V; + case 197: + return AV_CODEC_ID_AVUI; + case 199: + return AV_CODEC_ID_TARGA_Y216; + case 200: + return AV_CODEC_ID_V308; + case 201: + return AV_CODEC_ID_V408; + case 202: + return AV_CODEC_ID_YUV4; + case 203: + return AV_CODEC_ID_AVRN; + case 204: + return AV_CODEC_ID_CPIA; + case 205: + return AV_CODEC_ID_XFACE; + case 206: + return AV_CODEC_ID_SNOW; + case 207: + return AV_CODEC_ID_SMVJPEG; + case 208: + return AV_CODEC_ID_APNG; + case 209: + return AV_CODEC_ID_DAALA; + case 210: + return AV_CODEC_ID_CFHD; + case 211: + return AV_CODEC_ID_TRUEMOTION2RT; + case 212: + return AV_CODEC_ID_M101; + case 213: + return AV_CODEC_ID_MAGICYUV; + case 214: + return AV_CODEC_ID_SHEERVIDEO; + case 215: + return AV_CODEC_ID_YLC; + case 216: + return AV_CODEC_ID_PCM_S16LE; + case 217: + return AV_CODEC_ID_PCM_S16BE; + case 218: + return AV_CODEC_ID_PCM_U16LE; + case 219: + return AV_CODEC_ID_PCM_U16BE; + case 220: + return AV_CODEC_ID_PCM_S8; + case 221: + return AV_CODEC_ID_PCM_U8; + case 222: + return AV_CODEC_ID_PCM_MULAW; + case 223: + return AV_CODEC_ID_PCM_ALAW; + case 224: + return AV_CODEC_ID_PCM_S32LE; + case 225: + return AV_CODEC_ID_PCM_S32BE; + case 226: + return AV_CODEC_ID_PCM_U32LE; + case 227: + return AV_CODEC_ID_PCM_U32BE; + case 228: + return AV_CODEC_ID_PCM_S24LE; + case 229: + return AV_CODEC_ID_PCM_S24BE; + case 230: + return AV_CODEC_ID_PCM_U24LE; + case 231: + return AV_CODEC_ID_PCM_U24BE; + case 232: + return AV_CODEC_ID_PCM_S24DAUD; + case 233: + return AV_CODEC_ID_PCM_ZORK; + case 234: + return AV_CODEC_ID_PCM_S16LE_PLANAR; + case 235: + return AV_CODEC_ID_PCM_DVD; + case 236: + return AV_CODEC_ID_PCM_F32BE; + case 237: + return AV_CODEC_ID_PCM_F32LE; + case 238: + return AV_CODEC_ID_PCM_F64BE; + case 239: + return AV_CODEC_ID_PCM_F64LE; + case 240: + return AV_CODEC_ID_PCM_BLURAY; + case 241: + return AV_CODEC_ID_PCM_LXF; + case 242: + return AV_CODEC_ID_S302M; + case 243: + return AV_CODEC_ID_PCM_S8_PLANAR; + case 244: + return AV_CODEC_ID_PCM_S24LE_PLANAR; + case 245: + return AV_CODEC_ID_PCM_S32LE_PLANAR; + case 246: + return AV_CODEC_ID_PCM_S16BE_PLANAR; + case 247: + return AV_CODEC_ID_PCM_S64LE; + case 248: + return AV_CODEC_ID_PCM_S64BE; + case 249: + return AV_CODEC_ID_ADPCM_IMA_QT; + case 250: + return AV_CODEC_ID_ADPCM_IMA_WAV; + case 251: + return AV_CODEC_ID_ADPCM_IMA_DK3; + case 252: + return AV_CODEC_ID_ADPCM_IMA_DK4; + case 253: + return AV_CODEC_ID_ADPCM_IMA_WS; + case 254: + return AV_CODEC_ID_ADPCM_IMA_SMJPEG; + case 255: + return AV_CODEC_ID_ADPCM_MS; + case 256: + return AV_CODEC_ID_ADPCM_4XM; + case 257: + return AV_CODEC_ID_ADPCM_XA; + case 258: + return AV_CODEC_ID_ADPCM_ADX; + case 259: + return AV_CODEC_ID_ADPCM_EA; + case 260: + return AV_CODEC_ID_ADPCM_G726; + case 261: + return AV_CODEC_ID_ADPCM_CT; + case 262: + return AV_CODEC_ID_ADPCM_SWF; + case 263: + return AV_CODEC_ID_ADPCM_YAMAHA; + case 264: + return AV_CODEC_ID_ADPCM_SBPRO_4; + case 265: + return AV_CODEC_ID_ADPCM_SBPRO_3; + case 266: + return AV_CODEC_ID_ADPCM_SBPRO_2; + case 267: + return AV_CODEC_ID_ADPCM_THP; + case 268: + return AV_CODEC_ID_ADPCM_IMA_AMV; + case 269: + return AV_CODEC_ID_ADPCM_EA_R1; + case 270: + return AV_CODEC_ID_ADPCM_EA_R3; + case 271: + return AV_CODEC_ID_ADPCM_EA_R2; + case 272: + return AV_CODEC_ID_ADPCM_IMA_EA_SEAD; + case 273: + return AV_CODEC_ID_ADPCM_IMA_EA_EACS; + case 274: + return AV_CODEC_ID_ADPCM_EA_XAS; + case 275: + return AV_CODEC_ID_ADPCM_EA_MAXIS_XA; + case 276: + return AV_CODEC_ID_ADPCM_IMA_ISS; + case 277: + return AV_CODEC_ID_ADPCM_G722; + case 278: + return AV_CODEC_ID_ADPCM_IMA_APC; + case 279: + return AV_CODEC_ID_ADPCM_VIMA; + case 280: + return AV_CODEC_ID_ADPCM_AFC; + case 281: + return AV_CODEC_ID_ADPCM_IMA_OKI; + case 282: + return AV_CODEC_ID_ADPCM_DTK; + case 283: + return AV_CODEC_ID_ADPCM_IMA_RAD; + case 284: + return AV_CODEC_ID_ADPCM_G726LE; + case 285: + return AV_CODEC_ID_ADPCM_THP_LE; + case 286: + return AV_CODEC_ID_ADPCM_PSX; + case 287: + return AV_CODEC_ID_ADPCM_AICA; + case 288: + return AV_CODEC_ID_ADPCM_IMA_DAT4; + case 289: + return AV_CODEC_ID_ADPCM_MTAF; + case 290: + return AV_CODEC_ID_AMR_NB; + case 291: + return AV_CODEC_ID_AMR_WB; + case 292: + return AV_CODEC_ID_RA_144; + case 293: + return AV_CODEC_ID_RA_288; + case 294: + return AV_CODEC_ID_ROQ_DPCM; + case 295: + return AV_CODEC_ID_INTERPLAY_DPCM; + case 296: + return AV_CODEC_ID_XAN_DPCM; + case 297: + return AV_CODEC_ID_SOL_DPCM; + case 298: + return AV_CODEC_ID_SDX2_DPCM; + case 299: + return AV_CODEC_ID_MP2; + case 300: + return AV_CODEC_ID_MP3; + case 301: + return AV_CODEC_ID_AAC; + case 302: + return AV_CODEC_ID_AC3; + case 303: + return AV_CODEC_ID_DTS; + case 304: + return AV_CODEC_ID_VORBIS; + case 305: + return AV_CODEC_ID_DVAUDIO; + case 306: + return AV_CODEC_ID_WMAV1; + case 307: + return AV_CODEC_ID_WMAV2; + case 308: + return AV_CODEC_ID_MACE3; + case 309: + return AV_CODEC_ID_MACE6; + case 310: + return AV_CODEC_ID_VMDAUDIO; + case 311: + return AV_CODEC_ID_FLAC; + case 312: + return AV_CODEC_ID_MP3ADU; + case 313: + return AV_CODEC_ID_MP3ON4; + case 314: + return AV_CODEC_ID_SHORTEN; + case 315: + return AV_CODEC_ID_ALAC; + case 316: + return AV_CODEC_ID_WESTWOOD_SND1; + case 317: + return AV_CODEC_ID_GSM; + case 318: + return AV_CODEC_ID_QDM2; + case 319: + return AV_CODEC_ID_COOK; + case 320: + return AV_CODEC_ID_TRUESPEECH; + case 321: + return AV_CODEC_ID_TTA; + case 322: + return AV_CODEC_ID_SMACKAUDIO; + case 323: + return AV_CODEC_ID_QCELP; + case 324: + return AV_CODEC_ID_WAVPACK; + case 325: + return AV_CODEC_ID_DSICINAUDIO; + case 326: + return AV_CODEC_ID_IMC; + case 327: + return AV_CODEC_ID_MUSEPACK7; + case 328: + return AV_CODEC_ID_MLP; + case 329: + return AV_CODEC_ID_GSM_MS; + case 330: + return AV_CODEC_ID_ATRAC3; + // #[cfg(feature = "ff_api_voxware")] + // case 331: + // return AV_CODEC_ID_VOXWARE; + case 332: + return AV_CODEC_ID_APE; + case 333: + return AV_CODEC_ID_NELLYMOSER; + case 334: + return AV_CODEC_ID_MUSEPACK8; + case 335: + return AV_CODEC_ID_SPEEX; + case 336: + return AV_CODEC_ID_WMAVOICE; + case 337: + return AV_CODEC_ID_WMAPRO; + case 338: + return AV_CODEC_ID_WMALOSSLESS; + case 339: + return AV_CODEC_ID_ATRAC3P; + case 340: + return AV_CODEC_ID_EAC3; + case 341: + return AV_CODEC_ID_SIPR; + case 342: + return AV_CODEC_ID_MP1; + case 343: + return AV_CODEC_ID_TWINVQ; + case 344: + return AV_CODEC_ID_TRUEHD; + case 345: + return AV_CODEC_ID_MP4ALS; + case 346: + return AV_CODEC_ID_ATRAC1; + case 347: + return AV_CODEC_ID_BINKAUDIO_RDFT; + case 348: + return AV_CODEC_ID_BINKAUDIO_DCT; + case 349: + return AV_CODEC_ID_AAC_LATM; + case 350: + return AV_CODEC_ID_QDMC; + case 351: + return AV_CODEC_ID_CELT; + case 352: + return AV_CODEC_ID_G723_1; + case 353: + return AV_CODEC_ID_G729; + case 354: + return AV_CODEC_ID_8SVX_EXP; + case 355: + return AV_CODEC_ID_8SVX_FIB; + case 356: + return AV_CODEC_ID_BMV_AUDIO; + case 357: + return AV_CODEC_ID_RALF; + case 358: + return AV_CODEC_ID_IAC; + case 359: + return AV_CODEC_ID_ILBC; + case 360: + return AV_CODEC_ID_OPUS; + case 361: + return AV_CODEC_ID_COMFORT_NOISE; + case 362: + return AV_CODEC_ID_TAK; + case 363: + return AV_CODEC_ID_METASOUND; + case 364: + return AV_CODEC_ID_PAF_AUDIO; + case 365: + return AV_CODEC_ID_ON2AVC; + case 366: + return AV_CODEC_ID_DSS_SP; + case 367: + return AV_CODEC_ID_CODEC2; + case 368: + return AV_CODEC_ID_FFWAVESYNTH; + case 369: + return AV_CODEC_ID_SONIC; + case 370: + return AV_CODEC_ID_SONIC_LS; + case 371: + return AV_CODEC_ID_EVRC; + case 372: + return AV_CODEC_ID_SMV; + case 373: + return AV_CODEC_ID_DSD_LSBF; + case 374: + return AV_CODEC_ID_DSD_MSBF; + case 375: + return AV_CODEC_ID_DSD_LSBF_PLANAR; + case 376: + return AV_CODEC_ID_DSD_MSBF_PLANAR; + case 377: + return AV_CODEC_ID_4GV; + case 378: + return AV_CODEC_ID_INTERPLAY_ACM; + case 379: + return AV_CODEC_ID_XMA1; + case 380: + return AV_CODEC_ID_XMA2; + case 381: + return AV_CODEC_ID_DST; + ///////////// + ///////////// + ///////////// + case 382: + return AV_CODEC_ID_DVD_SUBTITLE; + case 383: + return AV_CODEC_ID_DVB_SUBTITLE; + case 384: + return AV_CODEC_ID_TEXT; + case 385: + return AV_CODEC_ID_XSUB; + case 386: + return AV_CODEC_ID_SSA; + case 387: + return AV_CODEC_ID_MOV_TEXT; + case 388: + return AV_CODEC_ID_HDMV_PGS_SUBTITLE; + case 389: + return AV_CODEC_ID_DVB_TELETEXT; + case 390: + return AV_CODEC_ID_SRT; + case 391: + return AV_CODEC_ID_MICRODVD; + case 392: + return AV_CODEC_ID_EIA_608; + case 393: + return AV_CODEC_ID_JACOSUB; + case 394: + return AV_CODEC_ID_SAMI; + case 395: + return AV_CODEC_ID_REALTEXT; + case 396: + return AV_CODEC_ID_STL; + case 397: + return AV_CODEC_ID_SUBVIEWER1; + case 398: + return AV_CODEC_ID_SUBVIEWER; + case 399: + return AV_CODEC_ID_SUBRIP; + case 400: + return AV_CODEC_ID_WEBVTT; + case 401: + return AV_CODEC_ID_MPL2; + case 402: + return AV_CODEC_ID_VPLAYER; + case 403: + return AV_CODEC_ID_PJS; + case 404: + return AV_CODEC_ID_ASS; + case 405: + return AV_CODEC_ID_HDMV_TEXT_SUBTITLE; + case 406: + return AV_CODEC_ID_TTF; + case 407: + return AV_CODEC_ID_SCTE_35; + case 408: + return AV_CODEC_ID_BINTEXT; + case 409: + return AV_CODEC_ID_XBIN; + case 410: + return AV_CODEC_ID_IDF; + case 411: + return AV_CODEC_ID_OTF; + case 412: + return AV_CODEC_ID_SMPTE_KLV; + case 413: + return AV_CODEC_ID_DVD_NAV; + case 414: + return AV_CODEC_ID_TIMED_ID3; + case 415: + return AV_CODEC_ID_BIN_DATA; + case 416: + return AV_CODEC_ID_PROBE; + case 417: + return AV_CODEC_ID_MPEG2TS; + case 418: + return AV_CODEC_ID_MPEG4SYSTEMS; + case 419: + return AV_CODEC_ID_FFMETADATA; + case 420: + return AV_CODEC_ID_WRAPPED_AVFRAME; + case 421: + return AV_CODEC_ID_PSD; + case 422: + return AV_CODEC_ID_PIXLET; + case 423: + return AV_CODEC_ID_SPEEDHQ; + case 424: + return AV_CODEC_ID_CLEARVIDEO; + case 425: + return AV_CODEC_ID_FMVC; + case 426: + return AV_CODEC_ID_SCPR; + case 427: + return AV_CODEC_ID_XPM; + case 428: + return AV_CODEC_ID_AV1; + case 429: + return AV_CODEC_ID_PCM_F16LE; + case 430: + return AV_CODEC_ID_PCM_F24LE; + //////////// + case 431: + return AV_CODEC_ID_ATRAC3AL; + case 432: + return AV_CODEC_ID_ATRAC3PAL; + case 433: + return AV_CODEC_ID_BITPACKED; + case 434: + return AV_CODEC_ID_MSCC; + case 435: + return AV_CODEC_ID_SRGC; + case 436: + return AV_CODEC_ID_SVG; + case 437: + return AV_CODEC_ID_GDV; + case 438: + return AV_CODEC_ID_FITS; + case 439: + return AV_CODEC_ID_GREMLIN_DPCM; + case 440: + return AV_CODEC_ID_DOLBY_E; + case 441: + return AV_CODEC_ID_APTX; + case 442: + return AV_CODEC_ID_APTX_HD; + case 443: + return AV_CODEC_ID_SBC; + case 444: + return AV_CODEC_ID_AVS2; + case 445: + return AV_CODEC_ID_IMM4; + case 446: + return AV_CODEC_ID_PROSUMER; + case 447: + return AV_CODEC_ID_MWSC; + case 448: + return AV_CODEC_ID_WCMV; + case 449: + return AV_CODEC_ID_RASC; + case 450: + return AV_CODEC_ID_PCM_VIDC; + case 451: + return AV_CODEC_ID_ATRAC9; + case 452: + return AV_CODEC_ID_TTML; + case 453: + return AV_CODEC_ID_HYMT; + case 454: + return AV_CODEC_ID_ARBC; + case 455: + return AV_CODEC_ID_AGM; + case 456: + return AV_CODEC_ID_LSCR; + case 457: + return AV_CODEC_ID_VP4; + case 458: + return AV_CODEC_ID_ADPCM_AGM; + case 459: + return AV_CODEC_ID_HCOM; + case 460: + return AV_CODEC_ID_ARIB_CAPTION; + case 461: + return AV_CODEC_ID_IMM5; + case 462: + return AV_CODEC_ID_MVDV; + case 463: + return AV_CODEC_ID_MVHA; + case 464: + return AV_CODEC_ID_CDTOONS; + case 465: + return AV_CODEC_ID_MV30; + case 466: + return AV_CODEC_ID_NOTCHLC; + case 467: + return AV_CODEC_ID_PFM; + case 468: + return AV_CODEC_ID_ARGO; + case 469: + return AV_CODEC_ID_ADPCM_IMA_SSI; + case 470: + return AV_CODEC_ID_ADPCM_ZORK; + case 471: + return AV_CODEC_ID_ADPCM_IMA_APM; + case 472: + return AV_CODEC_ID_ADPCM_IMA_ALP; + case 473: + return AV_CODEC_ID_ADPCM_IMA_MTF; + case 474: + return AV_CODEC_ID_ADPCM_IMA_CUNNING; + case 475: + return AV_CODEC_ID_DERF_DPCM; + case 476: + return AV_CODEC_ID_ACELP_KELVIN; + case 477: + return AV_CODEC_ID_MPEGH_3D_AUDIO; + case 478: + return AV_CODEC_ID_SIREN; + case 479: + return AV_CODEC_ID_HCA; + case 480: + return AV_CODEC_ID_EPG; + case 481: + return AV_CODEC_ID_AVS3; + case 482: + return AV_CODEC_ID_PGX; + case 483: + return AV_CODEC_ID_MSP2; + case 484: + return AV_CODEC_ID_VVC; + case 485: + return AV_CODEC_ID_MOBICLIP; + case 486: + return AV_CODEC_ID_PHOTOCD; + case 487: + return AV_CODEC_ID_ADPCM_ARGO; + case 488: + return AV_CODEC_ID_CRI; + case 489: + return AV_CODEC_ID_IPU; + case 490: + return AV_CODEC_ID_SIMBIOSIS_IMX; + case 491: + return AV_CODEC_ID_SGA_VIDEO; + case 492: + return AV_CODEC_ID_PCM_SGA; + case 493: + return AV_CODEC_ID_ADPCM_IMA_MOFLEX; + case 494: + return AV_CODEC_ID_FASTAUDIO; + case 495: + return AV_CODEC_ID_GEM; + case 496: + return AV_CODEC_ID_ADPCM_IMA_ACORN; + case 497: + return AV_CODEC_ID_MSNSIREN; + case 498: + return AV_CODEC_ID_VBN; + case 499: + return AV_CODEC_ID_JPEGXL; + case 500: + return AV_CODEC_ID_QOI; + case 501: + return AV_CODEC_ID_PHM; + case 502: + return AV_CODEC_ID_DFPWM; + case 503: + return AV_CODEC_ID_RADIANCE_HDR; + case 504: + return AV_CODEC_ID_WBMP; + case 505: + return AV_CODEC_ID_MEDIA100; + case 506: + return AV_CODEC_ID_VQC; + case 507: + return AV_CODEC_ID_ADPCM_XMD; + case 508: + return AV_CODEC_ID_WADY_DPCM; + case 509: + return AV_CODEC_ID_CBD2_DPCM; + case 510: + return AV_CODEC_ID_BONK; + case 511: + return AV_CODEC_ID_MISC4; + case 512: + return AV_CODEC_ID_APAC; + case 513: + return AV_CODEC_ID_FTR; + case 514: + return AV_CODEC_ID_WAVARC; + case 515: + return AV_CODEC_ID_RKA; + case 516: + return AV_CODEC_ID_VNULL; + case 517: + return AV_CODEC_ID_ANULL; + // case 518: + // return AV_CODEC_ID_MPEG2VIDEO_XVMC; + default: + return AV_CODEC_ID_NONE; + }; + } + + // Convert AVCodecID to uint32_t for rust SDK. + static uint32_t fromAVCodecID(AVCodecID AvCodecId) { + switch (AvCodecId) { + case AV_CODEC_ID_NONE: + return 0; + case AV_CODEC_ID_MPEG1VIDEO: + return 1; + case AV_CODEC_ID_MPEG2VIDEO: + return 2; + case AV_CODEC_ID_H261: + return 3; + case AV_CODEC_ID_H263: + return 4; + case AV_CODEC_ID_RV10: + return 5; + case AV_CODEC_ID_RV20: + return 6; + case AV_CODEC_ID_MJPEG: + return 7; + case AV_CODEC_ID_MJPEGB: + return 8; + case AV_CODEC_ID_LJPEG: + return 9; + case AV_CODEC_ID_SP5X: + return 10; + case AV_CODEC_ID_JPEGLS: + return 11; + case AV_CODEC_ID_MPEG4: + return 12; + case AV_CODEC_ID_RAWVIDEO: + return 13; + case AV_CODEC_ID_MSMPEG4V1: + return 14; + case AV_CODEC_ID_MSMPEG4V2: + return 15; + case AV_CODEC_ID_MSMPEG4V3: + return 16; + case AV_CODEC_ID_WMV1: + return 17; + case AV_CODEC_ID_WMV2: + return 18; + case AV_CODEC_ID_H263P: + return 19; + case AV_CODEC_ID_H263I: + return 20; + case AV_CODEC_ID_FLV1: + return 21; + case AV_CODEC_ID_SVQ1: + return 22; + case AV_CODEC_ID_SVQ3: + return 23; + case AV_CODEC_ID_DVVIDEO: + return 24; + case AV_CODEC_ID_HUFFYUV: + return 25; + case AV_CODEC_ID_CYUV: + return 26; + case AV_CODEC_ID_H264: + return 27; + case AV_CODEC_ID_INDEO3: + return 28; + case AV_CODEC_ID_VP3: + return 29; + case AV_CODEC_ID_THEORA: + return 30; + case AV_CODEC_ID_ASV1: + return 31; + case AV_CODEC_ID_ASV2: + return 32; + case AV_CODEC_ID_FFV1: + return 33; + case AV_CODEC_ID_4XM: + return 34; + case AV_CODEC_ID_VCR1: + return 35; + case AV_CODEC_ID_CLJR: + return 36; + case AV_CODEC_ID_MDEC: + return 37; + case AV_CODEC_ID_ROQ: + return 38; + case AV_CODEC_ID_INTERPLAY_VIDEO: + return 39; + case AV_CODEC_ID_XAN_WC3: + return 40; + case AV_CODEC_ID_XAN_WC4: + return 41; + case AV_CODEC_ID_RPZA: + return 42; + case AV_CODEC_ID_CINEPAK: + return 43; + case AV_CODEC_ID_WS_VQA: + return 44; + case AV_CODEC_ID_MSRLE: + return 45; + case AV_CODEC_ID_MSVIDEO1: + return 46; + case AV_CODEC_ID_IDCIN: + return 47; + case AV_CODEC_ID_8BPS: + return 48; + case AV_CODEC_ID_SMC: + return 49; + case AV_CODEC_ID_FLIC: + return 50; + case AV_CODEC_ID_TRUEMOTION1: + return 51; + case AV_CODEC_ID_VMDVIDEO: + return 52; + case AV_CODEC_ID_MSZH: + return 53; + case AV_CODEC_ID_ZLIB: + return 54; + case AV_CODEC_ID_QTRLE: + return 55; + case AV_CODEC_ID_TSCC: + return 56; + case AV_CODEC_ID_ULTI: + return 57; + case AV_CODEC_ID_QDRAW: + return 58; + case AV_CODEC_ID_VIXL: + return 59; + case AV_CODEC_ID_QPEG: + return 60; + case AV_CODEC_ID_PNG: + return 61; + case AV_CODEC_ID_PPM: + return 62; + case AV_CODEC_ID_PBM: + return 63; + case AV_CODEC_ID_PGM: + return 64; + case AV_CODEC_ID_PGMYUV: + return 65; + case AV_CODEC_ID_PAM: + return 66; + case AV_CODEC_ID_FFVHUFF: + return 67; + case AV_CODEC_ID_RV30: + return 68; + case AV_CODEC_ID_RV40: + return 69; + case AV_CODEC_ID_VC1: + return 70; + case AV_CODEC_ID_WMV3: + return 71; + case AV_CODEC_ID_LOCO: + return 72; + case AV_CODEC_ID_WNV1: + return 73; + case AV_CODEC_ID_AASC: + return 74; + case AV_CODEC_ID_INDEO2: + return 75; + case AV_CODEC_ID_FRAPS: + return 76; + case AV_CODEC_ID_TRUEMOTION2: + return 77; + case AV_CODEC_ID_BMP: + return 78; + case AV_CODEC_ID_CSCD: + return 79; + case AV_CODEC_ID_MMVIDEO: + return 80; + case AV_CODEC_ID_ZMBV: + return 81; + case AV_CODEC_ID_AVS: + return 82; + case AV_CODEC_ID_SMACKVIDEO: + return 83; + case AV_CODEC_ID_NUV: + return 84; + case AV_CODEC_ID_KMVC: + return 85; + case AV_CODEC_ID_FLASHSV: + return 86; + case AV_CODEC_ID_CAVS: + return 87; + case AV_CODEC_ID_JPEG2000: + return 88; + case AV_CODEC_ID_VMNC: + return 89; + case AV_CODEC_ID_VP5: + return 90; + case AV_CODEC_ID_VP6: + return 91; + case AV_CODEC_ID_VP6F: + return 92; + case AV_CODEC_ID_TARGA: + return 93; + case AV_CODEC_ID_DSICINVIDEO: + return 94; + case AV_CODEC_ID_TIERTEXSEQVIDEO: + return 95; + case AV_CODEC_ID_TIFF: + return 96; + case AV_CODEC_ID_GIF: + return 97; + case AV_CODEC_ID_DXA: + return 98; + case AV_CODEC_ID_DNXHD: + return 99; + case AV_CODEC_ID_THP: + return 100; + case AV_CODEC_ID_SGI: + return 101; + case AV_CODEC_ID_C93: + return 102; + case AV_CODEC_ID_BETHSOFTVID: + return 103; + case AV_CODEC_ID_PTX: + return 104; + case AV_CODEC_ID_TXD: + return 105; + case AV_CODEC_ID_VP6A: + return 106; + case AV_CODEC_ID_AMV: + return 107; + case AV_CODEC_ID_VB: + return 108; + case AV_CODEC_ID_PCX: + return 109; + case AV_CODEC_ID_SUNRAST: + return 110; + case AV_CODEC_ID_INDEO4: + return 111; + case AV_CODEC_ID_INDEO5: + return 112; + case AV_CODEC_ID_MIMIC: + return 113; + case AV_CODEC_ID_RL2: + return 114; + case AV_CODEC_ID_ESCAPE124: + return 115; + case AV_CODEC_ID_DIRAC: + return 116; + case AV_CODEC_ID_BFI: + return 117; + case AV_CODEC_ID_CMV: + return 118; + case AV_CODEC_ID_MOTIONPIXELS: + return 119; + case AV_CODEC_ID_TGV: + return 120; + case AV_CODEC_ID_TGQ: + return 121; + case AV_CODEC_ID_TQI: + return 122; + case AV_CODEC_ID_AURA: + return 123; + case AV_CODEC_ID_AURA2: + return 124; + case AV_CODEC_ID_V210X: + return 125; + case AV_CODEC_ID_TMV: + return 126; + case AV_CODEC_ID_V210: + return 127; + case AV_CODEC_ID_DPX: + return 128; + case AV_CODEC_ID_MAD: + return 129; + case AV_CODEC_ID_FRWU: + return 130; + case AV_CODEC_ID_FLASHSV2: + return 131; + case AV_CODEC_ID_CDGRAPHICS: + return 132; + case AV_CODEC_ID_R210: + return 133; + case AV_CODEC_ID_ANM: + return 134; + case AV_CODEC_ID_BINKVIDEO: + return 135; + case AV_CODEC_ID_IFF_ILBM: + return 136; + // case AV_CODEC_ID_IFF_ILBM: + // return 137; + case AV_CODEC_ID_KGV1: + return 138; + case AV_CODEC_ID_YOP: + return 139; + case AV_CODEC_ID_VP8: + return 140; + case AV_CODEC_ID_PICTOR: + return 141; + case AV_CODEC_ID_ANSI: + return 142; + case AV_CODEC_ID_A64_MULTI: + return 143; + case AV_CODEC_ID_A64_MULTI5: + return 144; + case AV_CODEC_ID_R10K: + return 145; + case AV_CODEC_ID_MXPEG: + return 146; + case AV_CODEC_ID_LAGARITH: + return 147; + case AV_CODEC_ID_PRORES: + return 148; + case AV_CODEC_ID_JV: + return 149; + case AV_CODEC_ID_DFA: + return 150; + case AV_CODEC_ID_WMV3IMAGE: + return 151; + case AV_CODEC_ID_VC1IMAGE: + return 152; + case AV_CODEC_ID_UTVIDEO: + return 153; + case AV_CODEC_ID_BMV_VIDEO: + return 154; + case AV_CODEC_ID_VBLE: + return 155; + case AV_CODEC_ID_DXTORY: + return 156; + case AV_CODEC_ID_V410: + return 157; + case AV_CODEC_ID_XWD: + return 158; + case AV_CODEC_ID_CDXL: + return 159; + case AV_CODEC_ID_XBM: + return 160; + case AV_CODEC_ID_ZEROCODEC: + return 161; + case AV_CODEC_ID_MSS1: + return 162; + case AV_CODEC_ID_MSA1: + return 163; + case AV_CODEC_ID_TSCC2: + return 164; + case AV_CODEC_ID_MTS2: + return 165; + case AV_CODEC_ID_CLLC: + return 166; + case AV_CODEC_ID_MSS2: + return 167; + case AV_CODEC_ID_VP9: + return 168; + case AV_CODEC_ID_AIC: + return 169; + case AV_CODEC_ID_ESCAPE130: + return 170; + case AV_CODEC_ID_G2M: + return 171; + case AV_CODEC_ID_WEBP: + return 172; + case AV_CODEC_ID_HNM4_VIDEO: + return 173; + case AV_CODEC_ID_HEVC: + return 174; + // case AV_CODEC_ID_HEVC: + // return 175; + case AV_CODEC_ID_FIC: + return 176; + case AV_CODEC_ID_ALIAS_PIX: + return 177; + case AV_CODEC_ID_BRENDER_PIX: + return 178; + case AV_CODEC_ID_PAF_VIDEO: + return 179; + case AV_CODEC_ID_EXR: + return 180; + case AV_CODEC_ID_VP7: + return 181; + case AV_CODEC_ID_SANM: + return 182; + case AV_CODEC_ID_SGIRLE: + return 183; + case AV_CODEC_ID_MVC1: + return 184; + case AV_CODEC_ID_MVC2: + return 185; + case AV_CODEC_ID_HQX: + return 186; + case AV_CODEC_ID_TDSC: + return 187; + case AV_CODEC_ID_HQ_HQA: + return 188; + case AV_CODEC_ID_HAP: + return 189; + case AV_CODEC_ID_DDS: + return 190; + case AV_CODEC_ID_DXV: + return 191; + case AV_CODEC_ID_SCREENPRESSO: + return 192; + case AV_CODEC_ID_RSCC: + return 193; + /////////////////////////////// + // return ; + // case AV_CODEC_ID_Y41P: + // return ; + // case AV_CODEC_ID_AVS2: + case AV_CODEC_ID_Y41P: + return 194; + case AV_CODEC_ID_AVRP: + return 195; + case AV_CODEC_ID_012V: + return 196; + case AV_CODEC_ID_AVUI: + return 197; + case AV_CODEC_ID_TARGA_Y216: + return 199; + case AV_CODEC_ID_V308: + return 200; + case AV_CODEC_ID_V408: + return 201; + case AV_CODEC_ID_YUV4: + return 202; + case AV_CODEC_ID_AVRN: + return 203; + case AV_CODEC_ID_CPIA: + return 204; + case AV_CODEC_ID_XFACE: + return 205; + case AV_CODEC_ID_SNOW: + return 206; + case AV_CODEC_ID_SMVJPEG: + return 207; + case AV_CODEC_ID_APNG: + return 208; + case AV_CODEC_ID_DAALA: + return 209; + case AV_CODEC_ID_CFHD: + return 210; + case AV_CODEC_ID_TRUEMOTION2RT: + return 211; + case AV_CODEC_ID_M101: + return 212; + case AV_CODEC_ID_MAGICYUV: + return 213; + case AV_CODEC_ID_SHEERVIDEO: + return 214; + case AV_CODEC_ID_YLC: + return 215; + // ================================= + // ================================= + // ================================= + case AV_CODEC_ID_PCM_S16LE: + return 216; + case AV_CODEC_ID_PCM_S16BE: + return 217; + case AV_CODEC_ID_PCM_U16LE: + return 218; + case AV_CODEC_ID_PCM_U16BE: + return 219; + case AV_CODEC_ID_PCM_S8: + return 220; + case AV_CODEC_ID_PCM_U8: + return 221; + case AV_CODEC_ID_PCM_MULAW: + return 222; + case AV_CODEC_ID_PCM_ALAW: + return 223; + case AV_CODEC_ID_PCM_S32LE: + return 224; + case AV_CODEC_ID_PCM_S32BE: + return 225; + case AV_CODEC_ID_PCM_U32LE: + return 226; + case AV_CODEC_ID_PCM_U32BE: + return 227; + case AV_CODEC_ID_PCM_S24LE: + return 228; + case AV_CODEC_ID_PCM_S24BE: + return 229; + case AV_CODEC_ID_PCM_U24LE: + return 230; + case AV_CODEC_ID_PCM_U24BE: + return 231; + case AV_CODEC_ID_PCM_S24DAUD: + return 232; + case AV_CODEC_ID_PCM_ZORK: + return 233; + case AV_CODEC_ID_PCM_S16LE_PLANAR: + return 234; + case AV_CODEC_ID_PCM_DVD: + return 235; + case AV_CODEC_ID_PCM_F32BE: + return 236; + case AV_CODEC_ID_PCM_F32LE: + return 237; + case AV_CODEC_ID_PCM_F64BE: + return 238; + case AV_CODEC_ID_PCM_F64LE: + return 239; + case AV_CODEC_ID_PCM_BLURAY: + return 240; + case AV_CODEC_ID_PCM_LXF: + return 241; + case AV_CODEC_ID_S302M: + return 242; + case AV_CODEC_ID_PCM_S8_PLANAR: + return 243; + case AV_CODEC_ID_PCM_S24LE_PLANAR: + return 244; + case AV_CODEC_ID_PCM_S32LE_PLANAR: + return 245; + case AV_CODEC_ID_PCM_S16BE_PLANAR: + return 246; + case AV_CODEC_ID_PCM_S64LE: + return 247; + case AV_CODEC_ID_PCM_S64BE: + return 248; + case AV_CODEC_ID_ADPCM_IMA_QT: + return 249; + case AV_CODEC_ID_ADPCM_IMA_WAV: + return 250; + case AV_CODEC_ID_ADPCM_IMA_DK3: + return 251; + case AV_CODEC_ID_ADPCM_IMA_DK4: + return 252; + case AV_CODEC_ID_ADPCM_IMA_WS: + return 253; + case AV_CODEC_ID_ADPCM_IMA_SMJPEG: + return 254; + case AV_CODEC_ID_ADPCM_MS: + return 255; + case AV_CODEC_ID_ADPCM_4XM: + return 256; + case AV_CODEC_ID_ADPCM_XA: + return 257; + case AV_CODEC_ID_ADPCM_ADX: + return 258; + case AV_CODEC_ID_ADPCM_EA: + return 259; + case AV_CODEC_ID_ADPCM_G726: + return 260; + case AV_CODEC_ID_ADPCM_CT: + return 261; + case AV_CODEC_ID_ADPCM_SWF: + return 262; + case AV_CODEC_ID_ADPCM_YAMAHA: + return 263; + case AV_CODEC_ID_ADPCM_SBPRO_4: + return 264; + case AV_CODEC_ID_ADPCM_SBPRO_3: + return 265; + case AV_CODEC_ID_ADPCM_SBPRO_2: + return 266; + case AV_CODEC_ID_ADPCM_THP: + return 267; + case AV_CODEC_ID_ADPCM_IMA_AMV: + return 268; + case AV_CODEC_ID_ADPCM_EA_R1: + return 269; + case AV_CODEC_ID_ADPCM_EA_R3: + return 270; + case AV_CODEC_ID_ADPCM_EA_R2: + return 271; + case AV_CODEC_ID_ADPCM_IMA_EA_SEAD: + return 272; + case AV_CODEC_ID_ADPCM_IMA_EA_EACS: + return 273; + case AV_CODEC_ID_ADPCM_EA_XAS: + return 274; + case AV_CODEC_ID_ADPCM_EA_MAXIS_XA: + return 275; + case AV_CODEC_ID_ADPCM_IMA_ISS: + return 276; + case AV_CODEC_ID_ADPCM_G722: + return 277; + case AV_CODEC_ID_ADPCM_IMA_APC: + return 278; + case AV_CODEC_ID_ADPCM_VIMA: + return 279; + case AV_CODEC_ID_ADPCM_AFC: + return 280; + case AV_CODEC_ID_ADPCM_IMA_OKI: + return 281; + case AV_CODEC_ID_ADPCM_DTK: + return 282; + case AV_CODEC_ID_ADPCM_IMA_RAD: + return 283; + case AV_CODEC_ID_ADPCM_G726LE: + return 284; + case AV_CODEC_ID_ADPCM_THP_LE: + return 285; + case AV_CODEC_ID_ADPCM_PSX: + return 286; + case AV_CODEC_ID_ADPCM_AICA: + return 287; + case AV_CODEC_ID_ADPCM_IMA_DAT4: + return 288; + case AV_CODEC_ID_ADPCM_MTAF: + return 289; + case AV_CODEC_ID_AMR_NB: + return 290; + case AV_CODEC_ID_AMR_WB: + return 291; + case AV_CODEC_ID_RA_144: + return 292; + case AV_CODEC_ID_RA_288: + return 293; + case AV_CODEC_ID_ROQ_DPCM: + return 294; + case AV_CODEC_ID_INTERPLAY_DPCM: + return 295; + case AV_CODEC_ID_XAN_DPCM: + return 296; + case AV_CODEC_ID_SOL_DPCM: + return 297; + case AV_CODEC_ID_SDX2_DPCM: + return 298; + case AV_CODEC_ID_MP2: + return 299; + case AV_CODEC_ID_MP3: + return 300; + case AV_CODEC_ID_AAC: + return 301; + case AV_CODEC_ID_AC3: + return 302; + case AV_CODEC_ID_DTS: + return 303; + case AV_CODEC_ID_VORBIS: + return 304; + case AV_CODEC_ID_DVAUDIO: + return 305; + case AV_CODEC_ID_WMAV1: + return 306; + case AV_CODEC_ID_WMAV2: + return 307; + case AV_CODEC_ID_MACE3: + return 308; + case AV_CODEC_ID_MACE6: + return 309; + case AV_CODEC_ID_VMDAUDIO: + return 310; + case AV_CODEC_ID_FLAC: + return 311; + case AV_CODEC_ID_MP3ADU: + return 312; + case AV_CODEC_ID_MP3ON4: + return 313; + case AV_CODEC_ID_SHORTEN: + return 314; + case AV_CODEC_ID_ALAC: + return 315; + case AV_CODEC_ID_WESTWOOD_SND1: + return 316; + case AV_CODEC_ID_GSM: + return 317; + case AV_CODEC_ID_QDM2: + return 318; + case AV_CODEC_ID_COOK: + return 319; + case AV_CODEC_ID_TRUESPEECH: + return 320; + case AV_CODEC_ID_TTA: + return 321; + case AV_CODEC_ID_SMACKAUDIO: + return 322; + case AV_CODEC_ID_QCELP: + return 323; + case AV_CODEC_ID_WAVPACK: + return 324; + case AV_CODEC_ID_DSICINAUDIO: + return 325; + case AV_CODEC_ID_IMC: + return 326; + case AV_CODEC_ID_MUSEPACK7: + return 327; + case AV_CODEC_ID_MLP: + return 328; + case AV_CODEC_ID_GSM_MS: + return 329; + case AV_CODEC_ID_ATRAC3: + return 330; + // #[cfg(feature = "ff_api_voxware")] + // case AV_CODEC_ID_VOXWARE: + // return 331; + case AV_CODEC_ID_APE: + return 332; + case AV_CODEC_ID_NELLYMOSER: + return 333; + case AV_CODEC_ID_MUSEPACK8: + return 334; + case AV_CODEC_ID_SPEEX: + return 335; + case AV_CODEC_ID_WMAVOICE: + return 336; + case AV_CODEC_ID_WMAPRO: + return 337; + case AV_CODEC_ID_WMALOSSLESS: + return 338; + case AV_CODEC_ID_ATRAC3P: + return 339; + case AV_CODEC_ID_EAC3: + return 340; + case AV_CODEC_ID_SIPR: + return 341; + case AV_CODEC_ID_MP1: + return 342; + case AV_CODEC_ID_TWINVQ: + return 343; + case AV_CODEC_ID_TRUEHD: + return 344; + case AV_CODEC_ID_MP4ALS: + return 345; + case AV_CODEC_ID_ATRAC1: + return 346; + case AV_CODEC_ID_BINKAUDIO_RDFT: + return 347; + case AV_CODEC_ID_BINKAUDIO_DCT: + return 348; + case AV_CODEC_ID_AAC_LATM: + return 349; + case AV_CODEC_ID_QDMC: + return 350; + case AV_CODEC_ID_CELT: + return 351; + case AV_CODEC_ID_G723_1: + return 352; + case AV_CODEC_ID_G729: + return 353; + case AV_CODEC_ID_8SVX_EXP: + return 354; + case AV_CODEC_ID_8SVX_FIB: + return 355; + case AV_CODEC_ID_BMV_AUDIO: + return 356; + case AV_CODEC_ID_RALF: + return 357; + case AV_CODEC_ID_IAC: + return 358; + case AV_CODEC_ID_ILBC: + return 359; + case AV_CODEC_ID_OPUS: + return 360; + case AV_CODEC_ID_COMFORT_NOISE: + return 361; + case AV_CODEC_ID_TAK: + return 362; + case AV_CODEC_ID_METASOUND: + return 363; + case AV_CODEC_ID_PAF_AUDIO: + return 364; + case AV_CODEC_ID_ON2AVC: + return 365; + case AV_CODEC_ID_DSS_SP: + return 366; + case AV_CODEC_ID_CODEC2: + return 367; + case AV_CODEC_ID_FFWAVESYNTH: + return 368; + case AV_CODEC_ID_SONIC: + return 369; + case AV_CODEC_ID_SONIC_LS: + return 370; + case AV_CODEC_ID_EVRC: + return 371; + case AV_CODEC_ID_SMV: + return 372; + case AV_CODEC_ID_DSD_LSBF: + return 373; + case AV_CODEC_ID_DSD_MSBF: + return 374; + case AV_CODEC_ID_DSD_LSBF_PLANAR: + return 375; + case AV_CODEC_ID_DSD_MSBF_PLANAR: + return 376; + case AV_CODEC_ID_4GV: + return 377; + case AV_CODEC_ID_INTERPLAY_ACM: + return 378; + case AV_CODEC_ID_XMA1: + return 379; + case AV_CODEC_ID_XMA2: + return 380; + case AV_CODEC_ID_DST: + return 381; + case AV_CODEC_ID_DVD_SUBTITLE: + return 382; + case AV_CODEC_ID_DVB_SUBTITLE: + return 383; + case AV_CODEC_ID_TEXT: + return 384; + case AV_CODEC_ID_XSUB: + return 385; + case AV_CODEC_ID_SSA: + return 386; + case AV_CODEC_ID_MOV_TEXT: + return 387; + case AV_CODEC_ID_HDMV_PGS_SUBTITLE: + return 388; + case AV_CODEC_ID_DVB_TELETEXT: + return 389; + case AV_CODEC_ID_SRT: + return 390; + case AV_CODEC_ID_MICRODVD: + return 391; + case AV_CODEC_ID_EIA_608: + return 392; + case AV_CODEC_ID_JACOSUB: + return 393; + case AV_CODEC_ID_SAMI: + return 394; + case AV_CODEC_ID_REALTEXT: + return 395; + case AV_CODEC_ID_STL: + return 396; + case AV_CODEC_ID_SUBVIEWER1: + return 397; + case AV_CODEC_ID_SUBVIEWER: + return 398; + case AV_CODEC_ID_SUBRIP: + return 399; + case AV_CODEC_ID_WEBVTT: + return 400; + case AV_CODEC_ID_MPL2: + return 401; + case AV_CODEC_ID_VPLAYER: + return 402; + case AV_CODEC_ID_PJS: + return 403; + case AV_CODEC_ID_ASS: + return 404; + case AV_CODEC_ID_HDMV_TEXT_SUBTITLE: + return 405; + case AV_CODEC_ID_TTF: + return 406; + case AV_CODEC_ID_SCTE_35: + return 407; + case AV_CODEC_ID_BINTEXT: + return 408; + case AV_CODEC_ID_XBIN: + return 409; + case AV_CODEC_ID_IDF: + return 410; + case AV_CODEC_ID_OTF: + return 411; + case AV_CODEC_ID_SMPTE_KLV: + return 412; + case AV_CODEC_ID_DVD_NAV: + return 413; + case AV_CODEC_ID_TIMED_ID3: + return 414; + case AV_CODEC_ID_BIN_DATA: + return 415; + case AV_CODEC_ID_PROBE: + return 416; + case AV_CODEC_ID_MPEG2TS: + return 417; + case AV_CODEC_ID_MPEG4SYSTEMS: + return 418; + case AV_CODEC_ID_FFMETADATA: + return 419; + case AV_CODEC_ID_WRAPPED_AVFRAME: + return 420; + case AV_CODEC_ID_PSD: + return 421; + case AV_CODEC_ID_PIXLET: + return 422; + case AV_CODEC_ID_SPEEDHQ: + return 423; + case AV_CODEC_ID_CLEARVIDEO: + return 424; + case AV_CODEC_ID_FMVC: + return 425; + case AV_CODEC_ID_SCPR: + return 426; + case AV_CODEC_ID_XPM: + return 427; + case AV_CODEC_ID_AV1: + return 428; + case AV_CODEC_ID_PCM_F16LE: + return 429; + case AV_CODEC_ID_PCM_F24LE: + return 430; + //////////// + case AV_CODEC_ID_ATRAC3AL: + return 431; + case AV_CODEC_ID_ATRAC3PAL: + return 432; + case AV_CODEC_ID_BITPACKED: + return 433; + case AV_CODEC_ID_MSCC: + return 434; + case AV_CODEC_ID_SRGC: + return 435; + case AV_CODEC_ID_SVG: + return 436; + case AV_CODEC_ID_GDV: + return 437; + case AV_CODEC_ID_FITS: + return 438; + case AV_CODEC_ID_GREMLIN_DPCM: + return 439; + case AV_CODEC_ID_DOLBY_E: + return 440; + case AV_CODEC_ID_APTX: + return 441; + case AV_CODEC_ID_APTX_HD: + return 442; + case AV_CODEC_ID_SBC: + return 443; + case AV_CODEC_ID_AVS2: + return 444; + case AV_CODEC_ID_IMM4: + return 445; + case AV_CODEC_ID_PROSUMER: + return 446; + case AV_CODEC_ID_MWSC: + return 447; + case AV_CODEC_ID_WCMV: + return 448; + case AV_CODEC_ID_RASC: + return 449; + case AV_CODEC_ID_PCM_VIDC: + return 450; + case AV_CODEC_ID_ATRAC9: + return 451; + case AV_CODEC_ID_TTML: + return 452; + case AV_CODEC_ID_HYMT: + return 453; + case AV_CODEC_ID_ARBC: + return 454; + case AV_CODEC_ID_AGM: + return 455; + case AV_CODEC_ID_LSCR: + return 456; + case AV_CODEC_ID_VP4: + return 457; + case AV_CODEC_ID_ADPCM_AGM: + return 458; + case AV_CODEC_ID_HCOM: + return 459; + case AV_CODEC_ID_ARIB_CAPTION: + return 460; + case AV_CODEC_ID_IMM5: + return 461; + case AV_CODEC_ID_MVDV: + return 462; + case AV_CODEC_ID_MVHA: + return 463; + case AV_CODEC_ID_CDTOONS: + return 464; + case AV_CODEC_ID_MV30: + return 465; + case AV_CODEC_ID_NOTCHLC: + return 466; + case AV_CODEC_ID_PFM: + return 467; + case AV_CODEC_ID_ARGO: + return 468; + case AV_CODEC_ID_ADPCM_IMA_SSI: + return 469; + case AV_CODEC_ID_ADPCM_ZORK: + return 470; + case AV_CODEC_ID_ADPCM_IMA_APM: + return 471; + case AV_CODEC_ID_ADPCM_IMA_ALP: + return 472; + case AV_CODEC_ID_ADPCM_IMA_MTF: + return 473; + case AV_CODEC_ID_ADPCM_IMA_CUNNING: + return 474; + case AV_CODEC_ID_DERF_DPCM: + return 475; + case AV_CODEC_ID_ACELP_KELVIN: + return 476; + case AV_CODEC_ID_MPEGH_3D_AUDIO: + return 477; + case AV_CODEC_ID_SIREN: + return 478; + case AV_CODEC_ID_HCA: + return 479; + case AV_CODEC_ID_EPG: + return 480; + case AV_CODEC_ID_AVS3: + return 481; + case AV_CODEC_ID_PGX: + return 482; + case AV_CODEC_ID_MSP2: + return 483; + case AV_CODEC_ID_VVC: + return 484; + case AV_CODEC_ID_MOBICLIP: + return 485; + case AV_CODEC_ID_PHOTOCD: + return 486; + case AV_CODEC_ID_ADPCM_ARGO: + return 487; + case AV_CODEC_ID_CRI: + return 488; + case AV_CODEC_ID_IPU: + return 489; + case AV_CODEC_ID_SIMBIOSIS_IMX: + return 490; + case AV_CODEC_ID_SGA_VIDEO: + return 491; + case AV_CODEC_ID_PCM_SGA: + return 492; + case AV_CODEC_ID_ADPCM_IMA_MOFLEX: + return 493; + case AV_CODEC_ID_FASTAUDIO: + return 494; + case AV_CODEC_ID_GEM: + return 495; + case AV_CODEC_ID_ADPCM_IMA_ACORN: + return 496; + case AV_CODEC_ID_MSNSIREN: + return 497; + case AV_CODEC_ID_VBN: + return 498; + case AV_CODEC_ID_JPEGXL: + return 499; + case AV_CODEC_ID_QOI: + return 500; + case AV_CODEC_ID_PHM: + return 501; + case AV_CODEC_ID_DFPWM: + return 502; + case AV_CODEC_ID_RADIANCE_HDR: + return 503; + case AV_CODEC_ID_WBMP: + return 504; + case AV_CODEC_ID_MEDIA100: + return 505; + case AV_CODEC_ID_VQC: + return 506; + case AV_CODEC_ID_ADPCM_XMD: + return 507; + case AV_CODEC_ID_WADY_DPCM: + return 508; + case AV_CODEC_ID_CBD2_DPCM: + return 509; + case AV_CODEC_ID_BONK: + return 510; + case AV_CODEC_ID_MISC4: + return 511; + case AV_CODEC_ID_APAC: + return 512; + case AV_CODEC_ID_FTR: + return 513; + case AV_CODEC_ID_WAVARC: + return 514; + case AV_CODEC_ID_RKA: + return 515; + case AV_CODEC_ID_VNULL: + return 516; + case AV_CODEC_ID_ANULL: + return 517; + // case AV_CODEC_ID_MPEG2VIDEO_XVMC: + // return 518; + default: + return 0; + } + } +}; + +class PixFmt { + +public: + static uint32_t fromAVPixFmt(AVPixelFormat AvPixelFormat) { + switch (AvPixelFormat) { + case AV_PIX_FMT_NONE: + return 0; + case AV_PIX_FMT_YUV420P: + return 1; + case AV_PIX_FMT_YUYV422: + return 2; + case AV_PIX_FMT_RGB24: + return 3; + case AV_PIX_FMT_BGR24: + return 4; + case AV_PIX_FMT_YUV422P: + return 5; + case AV_PIX_FMT_YUV444P: + return 7; + case AV_PIX_FMT_YUV410P: + return 8; + case AV_PIX_FMT_YUV411P: + return 9; + case AV_PIX_FMT_GRAY8: + return 10; + case AV_PIX_FMT_MONOWHITE: + return 11; + case AV_PIX_FMT_MONOBLACK: + return 12; + case AV_PIX_FMT_PAL8: + return 13; + case AV_PIX_FMT_YUVJ420P: + return 14; + case AV_PIX_FMT_YUVJ422P: + return 15; + case AV_PIX_FMT_YUVJ444P: + return 16; + // case AV_PIX_FMT_XVMC_MPEG2_MC : // Lower FFmpeg Version + // return 17; + // case AV_PIX_FMT_XVMC_MPEG2_IDCT : + // return 18; + case AV_PIX_FMT_UYVY422: + return 19; + case AV_PIX_FMT_UYYVYY411: + return 20; + case AV_PIX_FMT_BGR8: + return 21; + case AV_PIX_FMT_BGR4: + return 22; + case AV_PIX_FMT_BGR4_BYTE: + return 23; + case AV_PIX_FMT_RGB8: + return 24; + case AV_PIX_FMT_RGB4: + return 25; + case AV_PIX_FMT_RGB4_BYTE: + return 26; + case AV_PIX_FMT_NV12: + return 27; + case AV_PIX_FMT_NV21: + return 28; + case AV_PIX_FMT_ARGB: // Big Endian + return 29; + case AV_PIX_FMT_RGBA: // Big + return 30; + case AV_PIX_FMT_ABGR: // Big + return 31; + case AV_PIX_FMT_BGRA: // little + return 32; + case AV_PIX_FMT_GRAY16BE: // big + return 33; + case AV_PIX_FMT_GRAY16LE: + return 34; + case AV_PIX_FMT_YUV440P: + return 35; + case AV_PIX_FMT_YUVJ440P: + return 36; + case AV_PIX_FMT_YUVA420P: + return 37; + // case AV_PIX_FMT_VDPAU_H264 : + // return 38; + // case AV_PIX_FMT_VDPAU_MPEG1 : + // return 39; + // case AV_PIX_FMT_VDPAU_MPEG2 : + // return 40; + // case AV_PIX_FMT_VDPAU_WMV3 : // Conditional compile. + // return 41; + // case AV_PIX_FMT_VDPAU_VC1 : // ff_api_vdpau is present + // return 42; + case AV_PIX_FMT_RGB48BE: + return 43; + case AV_PIX_FMT_RGB48LE: + return 44; + case AV_PIX_FMT_RGB565BE: + return 45; + case AV_PIX_FMT_RGB565LE: + return 46; + case AV_PIX_FMT_RGB555BE: + return 47; + case AV_PIX_FMT_RGB555LE: + return 48; + case AV_PIX_FMT_BGR565BE: + return 49; + case AV_PIX_FMT_BGR565LE: + return 50; + case AV_PIX_FMT_BGR555BE: + return 51; + case AV_PIX_FMT_BGR555LE: + return 52; + // case AV_PIX_FMT_VAAPI_MOCO : + // return 53; + // case AV_PIX_FMT_VAAPI_IDCT : + // return 54; + // case AV_PIX_FMT_VAAPI_VLD : + // return 55; + // case AV_PIX_FMT_VAAPI : // ff_api_vdpau is present + // return 56; + case AV_PIX_FMT_YUV420P16LE: + return 57; + case AV_PIX_FMT_YUV420P16BE: + return 58; + case AV_PIX_FMT_YUV422P16LE: + return 59; + case AV_PIX_FMT_YUV422P16BE: + return 60; + case AV_PIX_FMT_YUV444P16LE: + return 61; + case AV_PIX_FMT_YUV444P16BE: + return 62; + // case AV_PIX_FMT_VDPAU_MPEG4 : // ff_api_vdpau is present + // return 63; + case AV_PIX_FMT_DXVA2_VLD: + return 64; + case AV_PIX_FMT_RGB444LE: + return 65; + case AV_PIX_FMT_RGB444BE: + return 66; + case AV_PIX_FMT_BGR444LE: + return 67; + case AV_PIX_FMT_BGR444BE: + return 68; + case AV_PIX_FMT_YA8: + return 69; + case AV_PIX_FMT_BGR48BE: + return 70; + case AV_PIX_FMT_BGR48LE: + return 71; + case AV_PIX_FMT_YUV420P9BE: + return 72; + case AV_PIX_FMT_YUV420P9LE: + return 73; + case AV_PIX_FMT_YUV420P10BE: + return 74; + case AV_PIX_FMT_YUV420P10LE: + return 75; + case AV_PIX_FMT_YUV422P10BE: + return 76; + case AV_PIX_FMT_YUV422P10LE: + return 77; + case AV_PIX_FMT_YUV444P9BE: + return 78; + case AV_PIX_FMT_YUV444P9LE: + return 79; + case AV_PIX_FMT_YUV444P10BE: + return 80; + case AV_PIX_FMT_YUV444P10LE: + return 81; + case AV_PIX_FMT_YUV422P9BE: + return 82; + case AV_PIX_FMT_YUV422P9LE: + return 83; + // case AV_PIX_FMT_VDA_VLD : // Lower than ffmpeg version 4 + // return 84; + case AV_PIX_FMT_GBRP: + return 85; + case AV_PIX_FMT_GBRP9BE: + return 86; + case AV_PIX_FMT_GBRP9LE: + return 87; + case AV_PIX_FMT_GBRP10BE: + return 88; + case AV_PIX_FMT_GBRP10LE: + return 89; + case AV_PIX_FMT_GBRP16BE: + return 90; + case AV_PIX_FMT_GBRP16LE: + return 91; + case AV_PIX_FMT_YUVA420P9BE: + return 92; + case AV_PIX_FMT_YUVA420P9LE: + return 93; + case AV_PIX_FMT_YUVA422P9BE: + return 94; + case AV_PIX_FMT_YUVA422P9LE: + return 95; + case AV_PIX_FMT_YUVA444P9BE: + return 96; + case AV_PIX_FMT_YUVA444P9LE: + return 97; + case AV_PIX_FMT_YUVA420P10BE: + return 98; + case AV_PIX_FMT_YUVA420P10LE: + return 99; + case AV_PIX_FMT_YUVA422P10BE: + return 100; + case AV_PIX_FMT_YUVA422P10LE: + return 101; + case AV_PIX_FMT_YUVA444P10BE: + return 102; + case AV_PIX_FMT_YUVA444P10LE: + return 103; + case AV_PIX_FMT_YUVA420P16BE: + return 104; + case AV_PIX_FMT_YUVA420P16LE: + return 105; + case AV_PIX_FMT_YUVA422P16BE: + return 106; + case AV_PIX_FMT_YUVA422P16LE: + return 107; + case AV_PIX_FMT_YUVA444P16BE: + return 108; + case AV_PIX_FMT_YUVA444P16LE: + return 109; + case AV_PIX_FMT_VDPAU: + return 110; + case AV_PIX_FMT_XYZ12LE: + return 111; + case AV_PIX_FMT_XYZ12BE: + return 112; + case AV_PIX_FMT_NV16: + return 113; + case AV_PIX_FMT_NV20LE: + return 114; + case AV_PIX_FMT_NV20BE: + return 115; + case AV_PIX_FMT_RGBA64BE: + return 116; + case AV_PIX_FMT_RGBA64LE: + return 117; + case AV_PIX_FMT_BGRA64BE: + return 118; + case AV_PIX_FMT_BGRA64LE: + return 119; + case AV_PIX_FMT_YVYU422: + return 120; + // case AV_PIX_FMT_VDA : // Lower than ffmpeg version 4. + // return 121; + case AV_PIX_FMT_YA16BE: // big + return 122; + case AV_PIX_FMT_YA16LE: + return 123; + case AV_PIX_FMT_QSV: + return 124; + case AV_PIX_FMT_MMAL: + return 125; + case AV_PIX_FMT_D3D11VA_VLD: + return 126; + case AV_PIX_FMT_CUDA: + return 127; + case AV_PIX_FMT_0RGB: // big + return 128; + case AV_PIX_FMT_RGB0: + return 129; + case AV_PIX_FMT_0BGR: // big + return 130; + case AV_PIX_FMT_BGR0: + return 131; + case AV_PIX_FMT_YUVA444P: + return 132; + case AV_PIX_FMT_YUVA422P: + return 133; + case AV_PIX_FMT_YUV420P12BE: + return 134; + case AV_PIX_FMT_YUV420P12LE: + return 135; + case AV_PIX_FMT_YUV420P14BE: + return 136; + case AV_PIX_FMT_YUV420P14LE: + return 137; + case AV_PIX_FMT_YUV422P12BE: + return 138; + case AV_PIX_FMT_YUV422P12LE: + return 139; + case AV_PIX_FMT_YUV422P14BE: + return 140; + case AV_PIX_FMT_YUV422P14LE: + return 141; + case AV_PIX_FMT_YUV444P12BE: + return 142; + case AV_PIX_FMT_YUV444P12LE: + return 143; + case AV_PIX_FMT_YUV444P14BE: + return 144; + case AV_PIX_FMT_YUV444P14LE: + return 146; + case AV_PIX_FMT_GBRP12BE: + return 147; + case AV_PIX_FMT_GBRP12LE: + return 148; + case AV_PIX_FMT_GBRP14BE: + return 149; + case AV_PIX_FMT_GBRP14LE: + return 150; + case AV_PIX_FMT_GBRAP: + return 151; + case AV_PIX_FMT_GBRAP16BE: + return 152; + case AV_PIX_FMT_GBRAP16LE: + return 153; + case AV_PIX_FMT_YUVJ411P: + return 154; + case AV_PIX_FMT_BAYER_BGGR8: + return 155; + case AV_PIX_FMT_BAYER_RGGB8: + return 156; + case AV_PIX_FMT_BAYER_GBRG8: + return 157; + case AV_PIX_FMT_BAYER_GRBG8: + return 158; + case AV_PIX_FMT_BAYER_BGGR16LE: + return 159; + case AV_PIX_FMT_BAYER_BGGR16BE: + return 160; + case AV_PIX_FMT_BAYER_RGGB16LE: + return 161; + case AV_PIX_FMT_BAYER_RGGB16BE: + return 162; + case AV_PIX_FMT_BAYER_GBRG16LE: + return 163; + case AV_PIX_FMT_BAYER_GBRG16BE: + return 164; + case AV_PIX_FMT_BAYER_GRBG16LE: + return 165; + case AV_PIX_FMT_BAYER_GRBG16BE: + return 166; + case AV_PIX_FMT_YUV440P10LE: + return 167; + case AV_PIX_FMT_YUV440P10BE: + return 168; + case AV_PIX_FMT_YUV440P12LE: + return 169; + case AV_PIX_FMT_YUV440P12BE: + return 170; + case AV_PIX_FMT_AYUV64LE: + return 171; + case AV_PIX_FMT_AYUV64BE: + return 172; + case AV_PIX_FMT_VIDEOTOOLBOX: + return 173; + // case AV_PIX_FMT_RGB32: // IF format is this type, based on + // endianness, it resolves to big endian or small endian. + // return 175; // The Switch case contains both big and + // small endian, so No need to add these in switch case. + // case AV_PIX_FMT_RGB32_1: // Will Automatically resolve. + // return 176; + // case AV_PIX_FMT_BGR32: + // return 177; + // case AV_PIX_FMT_BGR32_1: + // return 178; + // case AV_PIX_FMT_0RGB32: + // return 179; + // case AV_PIX_FMT_0BGR32: + // return 180; + // case AV_PIX_FMT_GRAY16: + // return 181; + // case AV_PIX_FMT_YA16: + // return 182; + // case AV_PIX_FMT_RGB48: + // return 183; + // case AV_PIX_FMT_RGB565: + // return 184; + // case AV_PIX_FMT_RGB444: + // return 185; + // case AV_PIX_FMT_BGR48: + // return 186; + // case AV_PIX_FMT_BGR565: + // return 187; + // case AV_PIX_FMT_BGR555: + // return 188; + // case AV_PIX_FMT_BGR444: + // return 189; + // case AV_PIX_FMT_YUV420P9: + // return 190; + // case AV_PIX_FMT_YUV422P9: + // return 191; + // case AV_PIX_FMT_YUV444P9: + // return 192; + // case AV_PIX_FMT_YUV420P10: + // return 193; + // case AV_PIX_FMT_YUV422P10: + // return 194; + // case AV_PIX_FMT_YUV440P10: + // return 195; + // case AV_PIX_FMT_YUV444P10: + // return 196; + // case AV_PIX_FMT_YUV420P12: + // return 197; + // case AV_PIX_FMT_YUV422P12: + // return 198; + // case AV_PIX_FMT_YUV440P12: + // return 199; + // case AV_PIX_FMT_YUV444P12: + // return 200; + // case AV_PIX_FMT_YUV420P14: + // return 201; + // case AV_PIX_FMT_YUV422P14: + // return 202; + // case AV_PIX_FMT_YUV444P14: + // return 203; + // case AV_PIX_FMT_YUV420P16: + // return 204; + // case AV_PIX_FMT_YUV422P16: + // return 205; + // case AV_PIX_FMT_YUV444P16: + // return 206; + // case AV_PIX_FMT_GBRP9: + // return 207; + // case AV_PIX_FMT_GBRP10: + // return 208; + // case AV_PIX_FMT_GBRP12: + // return 209; + // case AV_PIX_FMT_GBRP14: + // return 210; + // case AV_PIX_FMT_GBRP16: + // return 211; + // case AV_PIX_FMT_GBRAP16: + // return 212; + // case AV_PIX_FMT_BAYER_BGGR16: + // return 213; + // case AV_PIX_FMT_BAYER_RGGB16: + // return 214; + // case AV_PIX_FMT_BAYER_GBRG16: + // return 215; + // case AV_PIX_FMT_BAYER_GRBG16: + // return 216; + // case AV_PIX_FMT_YUVA420P9: + // return 217; + // case AV_PIX_FMT_YUVA422P9: + // return 218; + // case AV_PIX_FMT_YUVA444P9: + // return 219; + // case AV_PIX_FMT_YUVA420P10: + // return 220; + // case AV_PIX_FMT_YUVA422P10: + // return 221; + // case AV_PIX_FMT_YUVA444P10: + // return 222; + // case AV_PIX_FMT_YUVA420P16: + // return 223; + // case AV_PIX_FMT_YUVA422P16: + // return 224; + // case AV_PIX_FMT_YUVA444P16: + // return 225; + // case AV_PIX_FMT_XYZ12: + // return 226; + // case AV_PIX_FMT_NV20: + // return 227; + // case AV_PIX_FMT_AYUV64: + // return 228; + case AV_PIX_FMT_P010LE: + return 229; + case AV_PIX_FMT_P010BE: + return 230; + case AV_PIX_FMT_GBRAP12BE: + return 231; + case AV_PIX_FMT_GBRAP12LE: + return 232; + case AV_PIX_FMT_GBRAP10LE: + return 233; + case AV_PIX_FMT_GBRAP10BE: + return 234; + case AV_PIX_FMT_MEDIACODEC: + return 235; + case AV_PIX_FMT_GRAY12BE: + return 236; + case AV_PIX_FMT_GRAY12LE: + return 237; + case AV_PIX_FMT_GRAY10BE: + return 238; + case AV_PIX_FMT_GRAY10LE: + return 239; + case AV_PIX_FMT_P016LE: + return 240; + case AV_PIX_FMT_P016BE: + return 241; + case AV_PIX_FMT_D3D11: + return 242; + case AV_PIX_FMT_GRAY9BE: + return 243; + case AV_PIX_FMT_GRAY9LE: + return 244; + case AV_PIX_FMT_GBRPF32BE: + return 245; + case AV_PIX_FMT_GBRPF32LE: + return 246; + case AV_PIX_FMT_GBRAPF32BE: + return 247; + case AV_PIX_FMT_GBRAPF32LE: + return 248; + case AV_PIX_FMT_DRM_PRIME: + return 249; + // Above ffmpeg 4.0 Need to add versions. + case AV_PIX_FMT_OPENCL: + return 250; + case AV_PIX_FMT_GRAY14BE: + return 251; + case AV_PIX_FMT_GRAY14LE: + return 252; + case AV_PIX_FMT_GRAYF32BE: + return 253; + case AV_PIX_FMT_GRAYF32LE: + return 254; + case AV_PIX_FMT_YUVA422P12BE: + return 255; + case AV_PIX_FMT_YUVA422P12LE: + return 256; + case AV_PIX_FMT_YUVA444P12BE: + return 257; + case AV_PIX_FMT_YUVA444P12LE: + return 258; + case AV_PIX_FMT_NV24: + return 259; + case AV_PIX_FMT_NV42: + return 260; + case AV_PIX_FMT_VULKAN: + return 261; + case AV_PIX_FMT_Y210BE: + return 262; + case AV_PIX_FMT_Y210LE: + return 263; + case AV_PIX_FMT_X2RGB10LE: + return 264; + case AV_PIX_FMT_X2RGB10BE: + return 265; + case AV_PIX_FMT_X2BGR10LE: + return 266; + case AV_PIX_FMT_X2BGR10BE: + return 267; + case AV_PIX_FMT_P210BE: + return 268; + case AV_PIX_FMT_P210LE: + return 269; + case AV_PIX_FMT_P410BE: + return 270; + case AV_PIX_FMT_P410LE: + return 271; + case AV_PIX_FMT_P216BE: + return 272; + case AV_PIX_FMT_P216LE: + return 273; + case AV_PIX_FMT_P416BE: + return 274; + case AV_PIX_FMT_P416LE: + return 275; + case AV_PIX_FMT_VUYA: + return 276; + case AV_PIX_FMT_RGBAF16BE: + return 277; + case AV_PIX_FMT_RGBAF16LE: + return 278; + case AV_PIX_FMT_VUYX: + return 279; + case AV_PIX_FMT_P012LE: + return 280; + case AV_PIX_FMT_P012BE: + return 281; + case AV_PIX_FMT_Y212BE: + return 282; + case AV_PIX_FMT_Y212LE: + return 283; + case AV_PIX_FMT_XV30BE: + return 284; + case AV_PIX_FMT_XV30LE: + return 285; + case AV_PIX_FMT_XV36BE: + return 286; + case AV_PIX_FMT_XV36LE: + return 287; + case AV_PIX_FMT_RGBF32BE: + return 288; + case AV_PIX_FMT_RGBF32LE: + return 289; + case AV_PIX_FMT_RGBAF32BE: + return 290; + case AV_PIX_FMT_RGBAF32LE: + return 291; + // case AV_PIX_FMT_RPI : + // return 292; + // case AV_PIX_FMT_SAND128 : + // return 293; + // case AV_PIX_FMT_SAND64_10 : + // return 294; + // case AV_PIX_FMT_SAND64_16 : + // return 295; + // case AV_PIX_FMT_RPI4_8 : // rpi turn on then only + // return 296; + // case AV_PIX_FMT_RPI4_10 : + // return 297; + // case AV_PIX_FMT_RGB555: // Little Endian, Big Endian WIll + // Resolve on it's own. + // return 298; + default: + return 0; + } + } + + static AVPixelFormat intoAVPixFmt(uint32_t AvPixFmtId) { + switch (AvPixFmtId) { + case 0: + return AV_PIX_FMT_NONE; + case 1: + return AV_PIX_FMT_YUV420P; + case 2: + return AV_PIX_FMT_YUYV422; + case 3: + return AV_PIX_FMT_RGB24; + case 4: + return AV_PIX_FMT_BGR24; + case 5: + return AV_PIX_FMT_YUV422P; + case 7: + return AV_PIX_FMT_YUV444P; + case 8: + return AV_PIX_FMT_YUV410P; + case 9: + return AV_PIX_FMT_YUV411P; + case 10: + return AV_PIX_FMT_GRAY8; + case 11: + return AV_PIX_FMT_MONOWHITE; + case 12: + return AV_PIX_FMT_MONOBLACK; + case 13: + return AV_PIX_FMT_PAL8; + case 14: + return AV_PIX_FMT_YUVJ420P; + case 15: + return AV_PIX_FMT_YUVJ422P; + case 16: + return AV_PIX_FMT_YUVJ444P; + // case 17: + // return AV_PIX_FMT_XVMC_MPEG2_MC ; // Lower FFmpeg Version + // case 18: + // return AV_PIX_FMT_XVMC_MPEG2_IDCT ; + case 19: + return AV_PIX_FMT_UYVY422; + case 20: + return AV_PIX_FMT_UYYVYY411; + case 21: + return AV_PIX_FMT_BGR8; + case 22: + return AV_PIX_FMT_BGR4; + case 23: + return AV_PIX_FMT_BGR4_BYTE; + case 24: + return AV_PIX_FMT_RGB8; + case 25: + return AV_PIX_FMT_RGB4; + case 26: + return AV_PIX_FMT_RGB4_BYTE; + case 27: + return AV_PIX_FMT_NV12; + case 28: + return AV_PIX_FMT_NV21; + case 29: + return AV_PIX_FMT_ARGB; + case 30: + return AV_PIX_FMT_RGBA; + case 31: + return AV_PIX_FMT_ABGR; + case 32: + return AV_PIX_FMT_BGRA; // Little + case 33: + return AV_PIX_FMT_GRAY16BE; + case 34: + return AV_PIX_FMT_GRAY16LE; + case 35: + return AV_PIX_FMT_YUV440P; + case 36: + return AV_PIX_FMT_YUVJ440P; + case 37: + return AV_PIX_FMT_YUVA420P; + // case 38: + // return AV_PIX_FMT_VDPAU_H264 ; + // case 39: + // return AV_PIX_FMT_VDPAU_MPEG1 ; + // case 40: + // return AV_PIX_FMT_VDPAU_MPEG2 ; + // case 41: + // return AV_PIX_FMT_VDPAU_WMV3 ; // Conditional compile. + // case 42: + // return AV_PIX_FMT_VDPAU_VC1 ; // ff_api_vdpau is present + case 43: + return AV_PIX_FMT_RGB48BE; + case 44: + return AV_PIX_FMT_RGB48LE; + case 45: + return AV_PIX_FMT_RGB565BE; + case 46: + return AV_PIX_FMT_RGB565LE; + case 47: + return AV_PIX_FMT_RGB555BE; + case 48: + return AV_PIX_FMT_RGB555LE; + case 49: + return AV_PIX_FMT_BGR565BE; + case 50: + return AV_PIX_FMT_BGR565LE; + case 51: + return AV_PIX_FMT_BGR555BE; + case 52: + return AV_PIX_FMT_BGR555LE; + // case 53: + // return AV_PIX_FMT_VAAPI_MOCO ; + // case 54: + // return AV_PIX_FMT_VAAPI_IDCT ; + // case 55: + // return AV_PIX_FMT_VAAPI_VLD ; + // case 56: + // return AV_PIX_FMT_VAAPI ; // ff_api_vdpau is present + case 57: + return AV_PIX_FMT_YUV420P16LE; + case 58: + return AV_PIX_FMT_YUV420P16BE; + case 59: + return AV_PIX_FMT_YUV422P16LE; + case 60: + return AV_PIX_FMT_YUV422P16BE; + case 61: + return AV_PIX_FMT_YUV444P16LE; + case 62: + return AV_PIX_FMT_YUV444P16BE; + // case 63: + // return AV_PIX_FMT_VDPAU_MPEG4 ; // ff_api_vdpau is present + case 64: + return AV_PIX_FMT_DXVA2_VLD; + case 65: + return AV_PIX_FMT_RGB444LE; + case 66: + return AV_PIX_FMT_RGB444BE; + case 67: + return AV_PIX_FMT_BGR444LE; + case 68: + return AV_PIX_FMT_BGR444BE; + case 69: + return AV_PIX_FMT_YA8; + case 70: + return AV_PIX_FMT_BGR48BE; + case 71: + return AV_PIX_FMT_BGR48LE; + case 72: + return AV_PIX_FMT_YUV420P9BE; + case 73: + return AV_PIX_FMT_YUV420P9LE; + case 74: + return AV_PIX_FMT_YUV420P10BE; + case 75: + return AV_PIX_FMT_YUV420P10LE; + case 76: + return AV_PIX_FMT_YUV422P10BE; + case 77: + return AV_PIX_FMT_YUV422P10LE; + case 78: + return AV_PIX_FMT_YUV444P9BE; + case 79: + return AV_PIX_FMT_YUV444P9LE; + case 80: + return AV_PIX_FMT_YUV444P10BE; + case 81: + return AV_PIX_FMT_YUV444P10LE; + case 82: + return AV_PIX_FMT_YUV422P9BE; + case 83: + return AV_PIX_FMT_YUV422P9LE; + // case 84: + // return AV_PIX_FMT_VDA_VLD ; // Lower than ffmpeg version 4 + case 85: + return AV_PIX_FMT_GBRP; + case 86: + return AV_PIX_FMT_GBRP9BE; + case 87: + return AV_PIX_FMT_GBRP9LE; + case 88: + return AV_PIX_FMT_GBRP10BE; + case 89: + return AV_PIX_FMT_GBRP10LE; + case 90: + return AV_PIX_FMT_GBRP16BE; + case 91: + return AV_PIX_FMT_GBRP16LE; + case 92: + return AV_PIX_FMT_YUVA420P9BE; + case 93: + return AV_PIX_FMT_YUVA420P9LE; + case 94: + return AV_PIX_FMT_YUVA422P9BE; + case 95: + return AV_PIX_FMT_YUVA422P9LE; + case 96: + return AV_PIX_FMT_YUVA444P9BE; + case 97: + return AV_PIX_FMT_YUVA444P9LE; + case 98: + return AV_PIX_FMT_YUVA420P10BE; + case 99: + return AV_PIX_FMT_YUVA420P10LE; + case 100: + return AV_PIX_FMT_YUVA422P10BE; + case 101: + return AV_PIX_FMT_YUVA422P10LE; + case 102: + return AV_PIX_FMT_YUVA444P10BE; + case 103: + return AV_PIX_FMT_YUVA444P10LE; + case 104: + return AV_PIX_FMT_YUVA420P16BE; + case 105: + return AV_PIX_FMT_YUVA420P16LE; + case 106: + return AV_PIX_FMT_YUVA422P16BE; + case 107: + return AV_PIX_FMT_YUVA422P16LE; + case 108: + return AV_PIX_FMT_YUVA444P16BE; + case 109: + return AV_PIX_FMT_YUVA444P16LE; + case 110: + return AV_PIX_FMT_VDPAU; + case 111: + return AV_PIX_FMT_XYZ12LE; + case 112: + return AV_PIX_FMT_XYZ12BE; + case 113: + return AV_PIX_FMT_NV16; + case 114: + return AV_PIX_FMT_NV20LE; + case 115: + return AV_PIX_FMT_NV20BE; + case 116: + return AV_PIX_FMT_RGBA64BE; + case 117: + return AV_PIX_FMT_RGBA64LE; + case 118: + return AV_PIX_FMT_BGRA64BE; + case 119: + return AV_PIX_FMT_BGRA64LE; + case 120: + return AV_PIX_FMT_YVYU422; + // case 121: + // return AV_PIX_FMT_VDA ; // Lower than ffmpeg version 4. + case 122: + return AV_PIX_FMT_YA16BE; + case 123: + return AV_PIX_FMT_YA16LE; + case 124: + return AV_PIX_FMT_QSV; + case 125: + return AV_PIX_FMT_MMAL; + case 126: + return AV_PIX_FMT_D3D11VA_VLD; + case 127: + return AV_PIX_FMT_CUDA; + case 128: + return AV_PIX_FMT_0RGB; + case 129: + return AV_PIX_FMT_RGB0; + case 130: + return AV_PIX_FMT_0BGR; + case 131: + return AV_PIX_FMT_BGR0; + case 132: + return AV_PIX_FMT_YUVA444P; + case 133: + return AV_PIX_FMT_YUVA422P; + case 134: + return AV_PIX_FMT_YUV420P12BE; + case 135: + return AV_PIX_FMT_YUV420P12LE; + case 136: + return AV_PIX_FMT_YUV420P14BE; + case 137: + return AV_PIX_FMT_YUV420P14LE; + case 138: + return AV_PIX_FMT_YUV422P12BE; + case 139: + return AV_PIX_FMT_YUV422P12LE; + case 140: + return AV_PIX_FMT_YUV422P14BE; + case 141: + return AV_PIX_FMT_YUV422P14LE; + case 142: + return AV_PIX_FMT_YUV444P12BE; + case 143: + return AV_PIX_FMT_YUV444P12LE; + case 144: + return AV_PIX_FMT_YUV444P14BE; + case 146: + return AV_PIX_FMT_YUV444P14LE; + case 147: + return AV_PIX_FMT_GBRP12BE; + case 148: + return AV_PIX_FMT_GBRP12LE; + case 149: + return AV_PIX_FMT_GBRP14BE; + case 150: + return AV_PIX_FMT_GBRP14LE; + case 151: + return AV_PIX_FMT_GBRAP; + case 152: + return AV_PIX_FMT_GBRAP16BE; + case 153: + return AV_PIX_FMT_GBRAP16LE; + case 154: + return AV_PIX_FMT_YUVJ411P; + case 155: + return AV_PIX_FMT_BAYER_BGGR8; + case 156: + return AV_PIX_FMT_BAYER_RGGB8; + case 157: + return AV_PIX_FMT_BAYER_GBRG8; + case 158: + return AV_PIX_FMT_BAYER_GRBG8; + case 159: + return AV_PIX_FMT_BAYER_BGGR16LE; + case 160: + return AV_PIX_FMT_BAYER_BGGR16BE; + case 161: + return AV_PIX_FMT_BAYER_RGGB16LE; + case 162: + return AV_PIX_FMT_BAYER_RGGB16BE; + case 163: + return AV_PIX_FMT_BAYER_GBRG16LE; + case 164: + return AV_PIX_FMT_BAYER_GBRG16BE; + case 165: + return AV_PIX_FMT_BAYER_GRBG16LE; + case 166: + return AV_PIX_FMT_BAYER_GRBG16BE; + case 167: + return AV_PIX_FMT_YUV440P10LE; + case 168: + return AV_PIX_FMT_YUV440P10BE; + case 169: + return AV_PIX_FMT_YUV440P12LE; + case 170: + return AV_PIX_FMT_YUV440P12BE; + case 171: + return AV_PIX_FMT_AYUV64LE; + case 172: + return AV_PIX_FMT_AYUV64BE; + case 173: + return AV_PIX_FMT_VIDEOTOOLBOX; + case 175: + return AV_PIX_FMT_RGB32; + case 176: + return AV_PIX_FMT_RGB32_1; + case 177: + return AV_PIX_FMT_BGR32; + case 178: + return AV_PIX_FMT_BGR32_1; + case 179: + return AV_PIX_FMT_0RGB32; + case 180: + return AV_PIX_FMT_0BGR32; + case 181: + return AV_PIX_FMT_GRAY16; + case 182: + return AV_PIX_FMT_YA16; + case 183: + return AV_PIX_FMT_RGB48; + case 184: + return AV_PIX_FMT_RGB565; + case 185: + return AV_PIX_FMT_RGB444; + case 186: + return AV_PIX_FMT_BGR48; + case 187: + return AV_PIX_FMT_BGR565; + case 188: + return AV_PIX_FMT_BGR555; + case 189: + return AV_PIX_FMT_BGR444; + case 190: + return AV_PIX_FMT_YUV420P9; + case 191: + return AV_PIX_FMT_YUV422P9; + case 192: + return AV_PIX_FMT_YUV444P9; + case 193: + return AV_PIX_FMT_YUV420P10; + case 194: + return AV_PIX_FMT_YUV422P10; + case 195: + return AV_PIX_FMT_YUV440P10; + case 196: + return AV_PIX_FMT_YUV444P10; + case 197: + return AV_PIX_FMT_YUV420P12; + case 198: + return AV_PIX_FMT_YUV422P12; + case 199: + return AV_PIX_FMT_YUV440P12; + case 200: + return AV_PIX_FMT_YUV444P12; + case 201: + return AV_PIX_FMT_YUV420P14; + case 202: + return AV_PIX_FMT_YUV422P14; + case 203: + return AV_PIX_FMT_YUV444P14; + case 204: + return AV_PIX_FMT_YUV420P16; + case 205: + return AV_PIX_FMT_YUV422P16; + case 206: + return AV_PIX_FMT_YUV444P16; + case 207: + return AV_PIX_FMT_GBRP9; + case 208: + return AV_PIX_FMT_GBRP10; + case 209: + return AV_PIX_FMT_GBRP12; + case 210: + return AV_PIX_FMT_GBRP14; + case 211: + return AV_PIX_FMT_GBRP16; + case 212: + return AV_PIX_FMT_GBRAP16; + case 213: + return AV_PIX_FMT_BAYER_BGGR16; + case 214: + return AV_PIX_FMT_BAYER_RGGB16; + case 215: + return AV_PIX_FMT_BAYER_GBRG16; + case 216: + return AV_PIX_FMT_BAYER_GRBG16; + case 217: + return AV_PIX_FMT_YUVA420P9; + case 218: + return AV_PIX_FMT_YUVA422P9; + case 219: + return AV_PIX_FMT_YUVA444P9; + case 220: + return AV_PIX_FMT_YUVA420P10; + case 221: + return AV_PIX_FMT_YUVA422P10; + case 222: + return AV_PIX_FMT_YUVA444P10; + case 223: + return AV_PIX_FMT_YUVA420P16; + case 224: + return AV_PIX_FMT_YUVA422P16; + case 225: + return AV_PIX_FMT_YUVA444P16; + case 226: + return AV_PIX_FMT_XYZ12; + case 227: + return AV_PIX_FMT_NV20; + case 228: + return AV_PIX_FMT_AYUV64; + case 229: + return AV_PIX_FMT_P010LE; + case 230: + return AV_PIX_FMT_P010BE; + case 231: + return AV_PIX_FMT_GBRAP12BE; + case 232: + return AV_PIX_FMT_GBRAP12LE; + case 233: + return AV_PIX_FMT_GBRAP10LE; + case 234: + return AV_PIX_FMT_GBRAP10BE; + case 235: + return AV_PIX_FMT_MEDIACODEC; + case 236: + return AV_PIX_FMT_GRAY12BE; + case 237: + return AV_PIX_FMT_GRAY12LE; + case 238: + return AV_PIX_FMT_GRAY10BE; + case 239: + return AV_PIX_FMT_GRAY10LE; + case 240: + return AV_PIX_FMT_P016LE; + case 241: + return AV_PIX_FMT_P016BE; + case 242: + return AV_PIX_FMT_D3D11; + case 243: + return AV_PIX_FMT_GRAY9BE; + case 244: + return AV_PIX_FMT_GRAY9LE; + case 245: + return AV_PIX_FMT_GBRPF32BE; + case 246: + return AV_PIX_FMT_GBRPF32LE; + case 247: + return AV_PIX_FMT_GBRAPF32BE; + case 248: + return AV_PIX_FMT_GBRAPF32LE; + case 249: + return AV_PIX_FMT_DRM_PRIME; + + // Above ffmpeg 4.0 Need to add versions. + case 250: + return AV_PIX_FMT_OPENCL; + case 251: + return AV_PIX_FMT_GRAY14BE; + case 252: + return AV_PIX_FMT_GRAY14LE; + case 253: + return AV_PIX_FMT_GRAYF32BE; + case 254: + return AV_PIX_FMT_GRAYF32LE; + case 255: + return AV_PIX_FMT_YUVA422P12BE; + case 256: + return AV_PIX_FMT_YUVA422P12LE; + case 257: + return AV_PIX_FMT_YUVA444P12BE; + case 258: + return AV_PIX_FMT_YUVA444P12LE; + case 259: + return AV_PIX_FMT_NV24; + case 260: + return AV_PIX_FMT_NV42; + case 261: + return AV_PIX_FMT_VULKAN; + case 262: + return AV_PIX_FMT_Y210BE; + case 263: + return AV_PIX_FMT_Y210LE; + case 264: + return AV_PIX_FMT_X2RGB10LE; + case 265: + return AV_PIX_FMT_X2RGB10BE; + case 266: + return AV_PIX_FMT_X2BGR10LE; + case 267: + return AV_PIX_FMT_X2BGR10BE; + case 268: + return AV_PIX_FMT_P210BE; + case 269: + return AV_PIX_FMT_P210LE; + case 270: + return AV_PIX_FMT_P410BE; + case 271: + return AV_PIX_FMT_P410LE; + case 272: + return AV_PIX_FMT_P216BE; + case 273: + return AV_PIX_FMT_P216LE; + case 274: + return AV_PIX_FMT_P416BE; + case 275: + return AV_PIX_FMT_P416LE; + case 276: + return AV_PIX_FMT_VUYA; + case 277: + return AV_PIX_FMT_RGBAF16BE; + case 278: + return AV_PIX_FMT_RGBAF16LE; + case 279: + return AV_PIX_FMT_VUYX; + case 280: + return AV_PIX_FMT_P012LE; + case 281: + return AV_PIX_FMT_P012BE; + case 282: + return AV_PIX_FMT_Y212BE; + case 283: + return AV_PIX_FMT_Y212LE; + case 284: + return AV_PIX_FMT_XV30BE; + case 285: + return AV_PIX_FMT_XV30LE; + case 286: + return AV_PIX_FMT_XV36BE; + case 287: + return AV_PIX_FMT_XV36LE; + case 288: + return AV_PIX_FMT_RGBF32BE; + case 289: + return AV_PIX_FMT_RGBF32LE; + case 290: + return AV_PIX_FMT_RGBAF32BE; + case 291: + return AV_PIX_FMT_RGBAF32LE; + // case 292: + // return AV_PIX_FMT_RPI ; + // case 293: + // return AV_PIX_FMT_SAND128 ; + // case 294: + // return AV_PIX_FMT_SAND64_10 ; + // case 295: + // return AV_PIX_FMT_SAND64_16 ; + // case 296: + // return AV_PIX_FMT_RPI4_8 ; // rpi turn on then only + // case 297: + // return AV_PIX_FMT_RPI4_10 ; + case 298: + return AV_PIX_FMT_RGB555; + default: + return AV_PIX_FMT_NONE; + } + } +}; + +class SampleFmt { +public: + static AVSampleFormat fromSampleID(uint32_t SampleID) { + switch (SampleID) { + case 0: + return AV_SAMPLE_FMT_NONE; + case 1: + return AV_SAMPLE_FMT_U8; + case 2: + return AV_SAMPLE_FMT_S16; + case 3: + return AV_SAMPLE_FMT_S32; + case 4: + return AV_SAMPLE_FMT_FLT; + case 5: + return AV_SAMPLE_FMT_DBL; + case 6: + return AV_SAMPLE_FMT_U8P; + case 7: + return AV_SAMPLE_FMT_S16P; + case 8: + return AV_SAMPLE_FMT_S32P; + case 9: + return AV_SAMPLE_FMT_FLTP; + case 10: + return AV_SAMPLE_FMT_DBLP; + case 11: + return AV_SAMPLE_FMT_S64; + case 12: + return AV_SAMPLE_FMT_S64P; + case 13: + return AV_SAMPLE_FMT_NB; + default: + return AV_SAMPLE_FMT_NONE; + } + } + + static uint32_t toSampleID(AVSampleFormat AvSampleFormat) { + switch (AvSampleFormat) { + case AV_SAMPLE_FMT_NONE: + return 0; + case AV_SAMPLE_FMT_U8: + return 1; + case AV_SAMPLE_FMT_S16: + return 2; + case AV_SAMPLE_FMT_S32: + return 3; + case AV_SAMPLE_FMT_FLT: + return 4; + case AV_SAMPLE_FMT_DBL: + return 5; + case AV_SAMPLE_FMT_U8P: + return 6; + case AV_SAMPLE_FMT_S16P: + return 7; + case AV_SAMPLE_FMT_S32P: + return 8; + case AV_SAMPLE_FMT_FLTP: + return 9; + case AV_SAMPLE_FMT_DBLP: + return 10; + case AV_SAMPLE_FMT_S64: + return 11; + case AV_SAMPLE_FMT_S64P: + return 12; + case AV_SAMPLE_FMT_NB: + return 13; + default: + return 0; + } + } +}; + +// Could have avoided, but Did this to support older version of FFMPEG +// (V5,V4,V3) Version 6 FFmpeg uses AVChannelLayout Struct; +class ChannelLayout { +private: + const static uint64_t FRONT_LEFT = 1; + const static uint64_t FRONT_RIGHT = 1ULL << 1; + const static uint64_t FRONT_CENTER = 1ULL << 2; + const static uint64_t LOW_FREQUENCY = 1ULL << 3; + const static uint64_t BACK_LEFT = 1ULL << 4; + const static uint64_t BACK_RIGHT = 1ULL << 5; + const static uint64_t FRONT_LEFT_OF_CENTER = 1ULL << 6; + const static uint64_t FRONT_RIGHT_OF_CENTER = 1ULL << 7; + const static uint64_t BACK_CENTER = 1ULL << 8; + const static uint64_t SIDE_LEFT = 1ULL << 9; + const static uint64_t SIDE_RIGHT = 1ULL << 10; + const static uint64_t TOP_CENTER = 1ULL << 11; + const static uint64_t TOP_FRONT_LEFT = 1ULL << 12; + const static uint64_t TOP_FRONT_CENTER = 1ULL << 13; + const static uint64_t TOP_FRONT_RIGHT = 1ULL << 14; + const static uint64_t TOP_BACK_LEFT = 1ULL << 15; + const static uint64_t TOP_BACK_CENTER = 1ULL << 16; + const static uint64_t TOP_BACK_RIGHT = 1ULL << 17; + const static uint64_t STEREO_LEFT = 1ULL << 18; + const static uint64_t STEREO_RIGHT = 1ULL << 19; + const static uint64_t WIDE_LEFT = 1ULL << 20; + const static uint64_t WIDE_RIGHT = 1ULL << 21; + const static uint64_t SURROUND_DIRECT_LEFT = 1ULL << 22; + const static uint64_t SURROUND_DIRECT_RIGHT = 1ULL << 23; + const static uint64_t LOW_FREQUENCY_2 = 1ULL << 24; + + const static uint64_t MONO = 1ULL << 26; + const static uint64_t STEREO = 1ULL << 27; + const static uint64_t _2POINT1 = 1ULL << 28; + const static uint64_t _2_1 = 1ULL << 29; + const static uint64_t SURROUND = 1ULL << 30; + const static uint64_t _3POINT1 = 1ULL << 31; + const static uint64_t _4POINT0 = 1ULL << 32; + const static uint64_t _4POINT1 = 1ULL << 33; + const static uint64_t _2_2 = 1ULL << 34; + const static uint64_t QUAD = 1ULL << 35; + const static uint64_t _5POINT0 = 1ULL << 36; + const static uint64_t _5POINT1 = 1ULL << 37; + const static uint64_t _5POINT0_BACK = 1ULL << 38; + const static uint64_t _5POINT1_BACK = 1ULL << 39; + const static uint64_t _6POINT0 = 1ULL << 40; + const static uint64_t _6POINT0_FRONT = 1ULL << 41; + const static uint64_t HEXAGONAL = 1ULL << 42; + const static uint64_t _6POINT1 = 1ULL << 43; + const static uint64_t _6POINT1_BACK = 1ULL << 44; + const static uint64_t _6POINT1_FRONT = 1ULL << 45; + const static uint64_t _7POINT0 = 1ULL << 46; + const static uint64_t _7POINT0_FRONT = 1ULL << 47; + const static uint64_t _7POINT1 = 1ULL << 48; + const static uint64_t _7POINT1_WIDE = 1ULL << 49; + const static uint64_t _7POINT1_WIDE_BACK = 1ULL << 50; + const static uint64_t OCTAGONAL = 1ULL << 51; + const static uint64_t HEXADECAGONAL = 1ULL << 52; + const static uint64_t STEREO_DOWNMIX = 1ULL << 53; + +public: + // Check This function. (Looks good, test it) + static uint64_t fromChannelLayoutID(uint64_t ChannelLayout) { + uint64_t Channel = 0UL; + if (ChannelLayout & FRONT_LEFT) { + Channel |= AV_CH_FRONT_LEFT; + } + if (ChannelLayout & FRONT_RIGHT) { + Channel |= AV_CH_FRONT_RIGHT; + } + if (ChannelLayout & FRONT_CENTER) { + Channel |= AV_CH_FRONT_CENTER; + } + if (ChannelLayout & LOW_FREQUENCY) { + Channel |= AV_CH_LOW_FREQUENCY; + } + if (ChannelLayout & BACK_LEFT) { + Channel |= AV_CH_BACK_LEFT; + } + if (ChannelLayout & BACK_RIGHT) { + Channel |= AV_CH_BACK_RIGHT; + } + if (ChannelLayout & FRONT_LEFT_OF_CENTER) { + Channel |= AV_CH_FRONT_LEFT_OF_CENTER; + } + if (ChannelLayout & FRONT_RIGHT_OF_CENTER) { + Channel |= AV_CH_FRONT_RIGHT_OF_CENTER; + } + if (ChannelLayout & BACK_CENTER) { + Channel |= AV_CH_BACK_CENTER; + } + if (ChannelLayout & SIDE_LEFT) { + Channel |= AV_CH_SIDE_LEFT; + } + if (ChannelLayout & SIDE_RIGHT) { + Channel |= AV_CH_SIDE_RIGHT; + } + if (ChannelLayout & TOP_CENTER) { + Channel |= AV_CH_TOP_CENTER; + } + if (ChannelLayout & TOP_FRONT_LEFT) { + Channel |= AV_CH_TOP_FRONT_LEFT; + } + if (ChannelLayout & TOP_FRONT_CENTER) { + Channel |= AV_CH_TOP_FRONT_CENTER; + } + if (ChannelLayout & TOP_FRONT_RIGHT) { + Channel |= AV_CH_TOP_FRONT_RIGHT; + } + if (ChannelLayout & TOP_BACK_LEFT) { + Channel |= AV_CH_TOP_BACK_LEFT; + } + if (ChannelLayout & TOP_BACK_CENTER) { + Channel |= AV_CH_TOP_BACK_CENTER; + } + if (ChannelLayout & TOP_BACK_RIGHT) { + Channel |= AV_CH_TOP_BACK_RIGHT; + } + if (ChannelLayout & STEREO_LEFT) { + Channel |= AV_CH_STEREO_LEFT; + } + if (ChannelLayout & STEREO_RIGHT) { + Channel |= AV_CH_STEREO_RIGHT; + } + if (ChannelLayout & WIDE_LEFT) { + Channel |= AV_CH_WIDE_LEFT; + } + if (ChannelLayout & WIDE_RIGHT) { + Channel |= AV_CH_WIDE_RIGHT; + } + if (ChannelLayout & SURROUND_DIRECT_LEFT) { + Channel |= AV_CH_SURROUND_DIRECT_LEFT; + } + if (ChannelLayout & SURROUND_DIRECT_RIGHT) { + Channel |= AV_CH_SURROUND_DIRECT_RIGHT; + } + if (ChannelLayout & LOW_FREQUENCY_2) { + Channel |= AV_CH_LOW_FREQUENCY_2; + } + if (ChannelLayout & MONO) { + Channel |= AV_CH_LAYOUT_MONO; + } + if (ChannelLayout & STEREO) { + Channel |= AV_CH_LAYOUT_STEREO; + } + if (ChannelLayout & _2POINT1) { + Channel |= AV_CH_LAYOUT_2POINT1; + } + if (ChannelLayout & _2_1) { + Channel |= AV_CH_LAYOUT_2_1; + } + if (ChannelLayout & SURROUND) { + Channel |= AV_CH_LAYOUT_SURROUND; + } + if (ChannelLayout & _3POINT1) { + Channel |= AV_CH_LAYOUT_3POINT1; + } + if (ChannelLayout & _4POINT0) { + Channel |= AV_CH_LAYOUT_4POINT0; + } + if (ChannelLayout & _4POINT1) { + Channel |= AV_CH_LAYOUT_4POINT1; + } + if (ChannelLayout & _2_2) { + Channel |= AV_CH_LAYOUT_2_2; + } + if (ChannelLayout & QUAD) { + Channel |= AV_CH_LAYOUT_QUAD; + } + if (ChannelLayout & _5POINT0) { + Channel |= AV_CH_LAYOUT_5POINT0; + } + if (ChannelLayout & _5POINT1) { + Channel |= AV_CH_LAYOUT_5POINT1; + } + if (ChannelLayout & _5POINT0_BACK) { + Channel |= AV_CH_LAYOUT_5POINT0_BACK; + } + if (ChannelLayout & _5POINT1_BACK) { + Channel |= AV_CH_LAYOUT_5POINT1_BACK; + } + if (ChannelLayout & _6POINT0) { + Channel |= AV_CH_LAYOUT_6POINT0; + } + if (ChannelLayout & _6POINT0_FRONT) { + Channel |= AV_CH_LAYOUT_6POINT0_FRONT; + } + if (ChannelLayout & HEXAGONAL) { + Channel |= AV_CH_LAYOUT_HEXAGONAL; + } + if (ChannelLayout & _6POINT1) { + Channel |= AV_CH_LAYOUT_6POINT1; + } + if (ChannelLayout & _6POINT1_BACK) { + Channel |= AV_CH_LAYOUT_6POINT1_BACK; + } + if (ChannelLayout & _6POINT1_FRONT) { + Channel |= AV_CH_LAYOUT_6POINT1_FRONT; + } + if (ChannelLayout & _7POINT0) { + Channel |= AV_CH_LAYOUT_7POINT0; + } + if (ChannelLayout & _7POINT0_FRONT) { + Channel |= AV_CH_LAYOUT_7POINT0_FRONT; + } + if (ChannelLayout & _7POINT1) { + Channel |= AV_CH_LAYOUT_7POINT1; + } + if (ChannelLayout & _7POINT1_WIDE) { + Channel |= AV_CH_LAYOUT_7POINT1_WIDE; + } + if (ChannelLayout & _7POINT1_WIDE_BACK) { + Channel |= AV_CH_LAYOUT_7POINT1_WIDE_BACK; + } + if (ChannelLayout & OCTAGONAL) { + Channel |= AV_CH_LAYOUT_OCTAGONAL; + } + if (ChannelLayout & HEXADECAGONAL) { + Channel |= AV_CH_LAYOUT_HEXADECAGONAL; + } + if (ChannelLayout & STEREO_DOWNMIX) { + Channel |= AV_CH_LAYOUT_STEREO_DOWNMIX; + } + return Channel; + } + + // Perfect Logic :) + static uint64_t intoChannelLayoutID(uint64_t ChannelLayout) { + uint64_t Channel = 0; + if ((ChannelLayout & AV_CH_FRONT_LEFT) == AV_CH_FRONT_LEFT) { + Channel |= FRONT_LEFT; + } + if ((ChannelLayout & AV_CH_FRONT_RIGHT) == AV_CH_FRONT_RIGHT) { + Channel |= FRONT_RIGHT; + } + if ((ChannelLayout & AV_CH_FRONT_CENTER) == AV_CH_FRONT_CENTER) { + Channel |= FRONT_CENTER; + } + if ((ChannelLayout & AV_CH_LOW_FREQUENCY) == AV_CH_LOW_FREQUENCY) { + Channel |= LOW_FREQUENCY; + } + if ((ChannelLayout & AV_CH_BACK_LEFT) == AV_CH_BACK_LEFT) { + Channel |= BACK_LEFT; + } + if ((ChannelLayout & AV_CH_BACK_RIGHT) == AV_CH_BACK_RIGHT) { + Channel |= BACK_RIGHT; + } + if ((ChannelLayout & AV_CH_FRONT_LEFT_OF_CENTER) == + AV_CH_FRONT_LEFT_OF_CENTER) { + Channel |= FRONT_LEFT_OF_CENTER; + } + if ((ChannelLayout & AV_CH_FRONT_RIGHT_OF_CENTER) == + AV_CH_FRONT_RIGHT_OF_CENTER) { + Channel |= FRONT_RIGHT_OF_CENTER; + } + if ((ChannelLayout & AV_CH_BACK_CENTER) == AV_CH_BACK_CENTER) { + Channel |= BACK_CENTER; + } + if ((ChannelLayout & AV_CH_SIDE_LEFT) == AV_CH_SIDE_LEFT) { + Channel |= SIDE_LEFT; + } + if ((ChannelLayout & AV_CH_SIDE_RIGHT) == AV_CH_SIDE_RIGHT) { + Channel |= SIDE_RIGHT; + } + if ((ChannelLayout & AV_CH_TOP_CENTER) == AV_CH_TOP_CENTER) { + Channel |= TOP_CENTER; + } + if ((ChannelLayout & AV_CH_TOP_FRONT_LEFT) == AV_CH_TOP_FRONT_LEFT) { + Channel |= TOP_FRONT_LEFT; + } + if ((ChannelLayout & AV_CH_TOP_FRONT_CENTER) == AV_CH_TOP_FRONT_CENTER) { + Channel |= TOP_FRONT_CENTER; + } + if ((ChannelLayout & AV_CH_TOP_FRONT_RIGHT) == AV_CH_TOP_FRONT_RIGHT) { + Channel |= TOP_FRONT_RIGHT; + } + if ((ChannelLayout & AV_CH_TOP_BACK_LEFT) == AV_CH_TOP_BACK_LEFT) { + Channel |= TOP_BACK_LEFT; + } + if ((ChannelLayout & AV_CH_TOP_BACK_CENTER) == AV_CH_TOP_BACK_CENTER) { + Channel |= TOP_BACK_CENTER; + } + if ((ChannelLayout & AV_CH_TOP_BACK_RIGHT) == AV_CH_TOP_BACK_RIGHT) { + Channel |= TOP_BACK_RIGHT; + } + if ((ChannelLayout & AV_CH_STEREO_LEFT) == AV_CH_STEREO_LEFT) { + Channel |= STEREO_LEFT; + } + if ((ChannelLayout & AV_CH_STEREO_RIGHT) == AV_CH_STEREO_RIGHT) { + Channel |= STEREO_RIGHT; + } + if ((ChannelLayout & AV_CH_WIDE_LEFT) == AV_CH_WIDE_LEFT) { + Channel |= WIDE_LEFT; + } + if ((ChannelLayout & AV_CH_WIDE_RIGHT) == AV_CH_WIDE_RIGHT) { + Channel |= WIDE_RIGHT; + } + if ((ChannelLayout & AV_CH_SURROUND_DIRECT_LEFT) == + AV_CH_SURROUND_DIRECT_LEFT) { + Channel |= SURROUND_DIRECT_LEFT; + } + if ((ChannelLayout & AV_CH_SURROUND_DIRECT_RIGHT) == + AV_CH_SURROUND_DIRECT_RIGHT) { + Channel |= SURROUND_DIRECT_RIGHT; + } + if ((ChannelLayout & AV_CH_LOW_FREQUENCY_2) == AV_CH_LOW_FREQUENCY_2) { + Channel |= LOW_FREQUENCY_2; + } + + // Channel Mask C; + if ((ChannelLayout & AV_CH_LAYOUT_MONO) == AV_CH_LAYOUT_MONO) { + Channel |= MONO; + } + if ((ChannelLayout & AV_CH_LAYOUT_STEREO) == AV_CH_LAYOUT_STEREO) { + Channel |= STEREO; + } + if ((ChannelLayout & AV_CH_LAYOUT_2POINT1) == AV_CH_LAYOUT_2POINT1) { + Channel |= _2POINT1; + } + if ((ChannelLayout & AV_CH_LAYOUT_2_1) == AV_CH_LAYOUT_2_1) { + Channel |= _2_1; + } + if ((ChannelLayout & AV_CH_LAYOUT_SURROUND) == AV_CH_LAYOUT_SURROUND) { + Channel |= SURROUND; + } + if ((ChannelLayout & AV_CH_LAYOUT_3POINT1) == AV_CH_LAYOUT_3POINT1) { + Channel |= _3POINT1; + } + if ((ChannelLayout & AV_CH_LAYOUT_4POINT0) == AV_CH_LAYOUT_4POINT0) { + Channel |= _4POINT0; + } + if ((ChannelLayout & AV_CH_LAYOUT_4POINT1) == AV_CH_LAYOUT_4POINT1) { + Channel |= _4POINT1; + } + if ((ChannelLayout & AV_CH_LAYOUT_2_2) == AV_CH_LAYOUT_2_2) { + Channel |= _2_2; + } + if ((ChannelLayout & AV_CH_LAYOUT_QUAD) == AV_CH_LAYOUT_QUAD) { + Channel |= QUAD; + } + if ((ChannelLayout & AV_CH_LAYOUT_5POINT0) == AV_CH_LAYOUT_5POINT0) { + Channel |= _5POINT0; + } + if ((ChannelLayout & AV_CH_LAYOUT_5POINT1) == AV_CH_LAYOUT_5POINT1) { + Channel |= _5POINT1; + } + if ((ChannelLayout & AV_CH_LAYOUT_5POINT0_BACK) == + AV_CH_LAYOUT_5POINT0_BACK) { + Channel |= _5POINT0_BACK; + } + if ((ChannelLayout & AV_CH_LAYOUT_5POINT1_BACK) == + AV_CH_LAYOUT_5POINT1_BACK) { + Channel |= _5POINT1_BACK; + } + if ((ChannelLayout & AV_CH_LAYOUT_6POINT0) == AV_CH_LAYOUT_6POINT0) { + Channel |= _6POINT0; + } + if ((ChannelLayout & AV_CH_LAYOUT_6POINT0_FRONT) == + AV_CH_LAYOUT_6POINT0_FRONT) { + Channel |= _6POINT0_FRONT; + } + if ((ChannelLayout & AV_CH_LAYOUT_HEXAGONAL) == AV_CH_LAYOUT_HEXAGONAL) { + Channel |= HEXAGONAL; + } + if ((ChannelLayout & AV_CH_LAYOUT_6POINT1) == AV_CH_LAYOUT_6POINT1) { + Channel |= _6POINT1; + } + if ((ChannelLayout & AV_CH_LAYOUT_6POINT1_BACK) == + AV_CH_LAYOUT_6POINT1_BACK) { + Channel |= _6POINT1_BACK; + } + if ((ChannelLayout & AV_CH_LAYOUT_6POINT1_FRONT) == + AV_CH_LAYOUT_6POINT1_FRONT) { + Channel |= _6POINT1_FRONT; + } + if ((ChannelLayout & AV_CH_LAYOUT_7POINT0) == AV_CH_LAYOUT_7POINT0) { + Channel |= _7POINT0; + } + if ((ChannelLayout & AV_CH_LAYOUT_7POINT0_FRONT) == + AV_CH_LAYOUT_7POINT0_FRONT) { + Channel |= _7POINT0_FRONT; + } + if ((ChannelLayout & AV_CH_LAYOUT_7POINT1) == AV_CH_LAYOUT_7POINT1) { + Channel |= _7POINT1; + } + if ((ChannelLayout & AV_CH_LAYOUT_7POINT1_WIDE) == + AV_CH_LAYOUT_7POINT1_WIDE) { + Channel |= _7POINT1_WIDE; + } + if ((ChannelLayout & AV_CH_LAYOUT_7POINT1_WIDE_BACK) == + AV_CH_LAYOUT_7POINT1_WIDE_BACK) { + Channel |= _7POINT1_WIDE_BACK; + } + if ((ChannelLayout & AV_CH_LAYOUT_OCTAGONAL) == AV_CH_LAYOUT_OCTAGONAL) { + Channel |= OCTAGONAL; + } + if ((ChannelLayout & AV_CH_LAYOUT_HEXADECAGONAL) == + AV_CH_LAYOUT_HEXADECAGONAL) { + Channel |= HEXADECAGONAL; + } + if ((ChannelLayout & AV_CH_LAYOUT_STEREO_DOWNMIX) == + AV_CH_LAYOUT_STEREO_DOWNMIX) { + Channel |= STEREO_DOWNMIX; + } + return Channel; + } +}; + +class SWRFilterType { +public: + uint32_t fromSwrFilterType(SwrFilterType FilterType) { + switch (FilterType) { + case SWR_FILTER_TYPE_CUBIC: + return 1; + case SWR_FILTER_TYPE_BLACKMAN_NUTTALL: + return 2; + case SWR_FILTER_TYPE_KAISER: + return 3; + default: + return 1; + } + } + + SwrFilterType intoSwrFilterType(uint32_t FilterID) { + switch (FilterID) { + case 1: + return SWR_FILTER_TYPE_CUBIC; + case 2: + return SWR_FILTER_TYPE_BLACKMAN_NUTTALL; + case 3: + return SWR_FILTER_TYPE_KAISER; + default: + return SWR_FILTER_TYPE_CUBIC; + } + } +}; + +class SWREngine { +public: + SwrEngine intoSwrEngine(uint32_t EngineId) { + switch (EngineId) { + case 1: + return SWR_ENGINE_SWR; + case 2: + return SWR_ENGINE_SOXR; + default: + return SWR_ENGINE_SWR; + } + } + + uint32_t fromSwrEngine(SwrEngine Engine) { + switch (Engine) { + case SWR_ENGINE_SWR: + return 1; + case SWR_ENGINE_SOXR: + return 2; + case SWR_ENGINE_NB: + return 3; + default: + return SWR_ENGINE_SWR; + } + } +}; + +class SWRDitherType { +public: + SwrDitherType intoSwrDitherType(uint32_t SwrDitherId) { + switch (SwrDitherId) { + case 0: + return SWR_DITHER_NONE; + case 1: + return SWR_DITHER_RECTANGULAR; + case 2: + return SWR_DITHER_TRIANGULAR; + case 3: + return SWR_DITHER_TRIANGULAR_HIGHPASS; + case 64: + return SWR_DITHER_NS; + case 4: + return SWR_DITHER_NS_LIPSHITZ; + case 5: + return SWR_DITHER_NS_F_WEIGHTED; + case 6: + return SWR_DITHER_NS_MODIFIED_E_WEIGHTED; + case 7: + return SWR_DITHER_NS_IMPROVED_E_WEIGHTED; + case 8: + return SWR_DITHER_NS_SHIBATA; + case 9: + return SWR_DITHER_NS_LOW_SHIBATA; + case 10: + return SWR_DITHER_NS_HIGH_SHIBATA; + case 11: + return SWR_DITHER_NB; + default: + return SWR_DITHER_NONE; + } + } + + uint32_t fromSwrDitherType(SwrDitherType SwrDitherType) { + switch (SwrDitherType) { + case SWR_DITHER_NONE: + return 0; + case SWR_DITHER_RECTANGULAR: + return 1; + case SWR_DITHER_TRIANGULAR: + return 2; + case SWR_DITHER_TRIANGULAR_HIGHPASS: + return 3; + case SWR_DITHER_NS: + return 64; + case SWR_DITHER_NS_LIPSHITZ: + return 4; + case SWR_DITHER_NS_F_WEIGHTED: + return 5; + case SWR_DITHER_NS_MODIFIED_E_WEIGHTED: + return 6; + case SWR_DITHER_NS_IMPROVED_E_WEIGHTED: + return 7; + case SWR_DITHER_NS_SHIBATA: + return 8; + case SWR_DITHER_NS_LOW_SHIBATA: + return 9; + case SWR_DITHER_NS_HIGH_SHIBATA: + return 10; + case SWR_DITHER_NB: + return 11; + default: + return 0; + } + } +}; + +class ChromaLocation { +public: + static AVChromaLocation intoAVChromaLocation(int32_t ChromaLocationId) { + switch (ChromaLocationId) { + case 0: + return AVCHROMA_LOC_UNSPECIFIED; + case 1: + return AVCHROMA_LOC_LEFT; + case 2: + return AVCHROMA_LOC_CENTER; + case 3: + return AVCHROMA_LOC_TOPLEFT; + case 4: + return AVCHROMA_LOC_TOP; + case 5: + return AVCHROMA_LOC_BOTTOMLEFT; + case 6: + return AVCHROMA_LOC_BOTTOM; + default: + return AVCHROMA_LOC_UNSPECIFIED; + } + } + + static int32_t fromAVChromaLocation(AVChromaLocation ChromaLocation) { + switch (ChromaLocation) { + case AVCHROMA_LOC_UNSPECIFIED: + return 0; + case AVCHROMA_LOC_LEFT: + return 1; + case AVCHROMA_LOC_CENTER: + return 2; + case AVCHROMA_LOC_TOPLEFT: + return 3; + case AVCHROMA_LOC_TOP: + return 4; + case AVCHROMA_LOC_BOTTOMLEFT: + return 5; + case AVCHROMA_LOC_BOTTOM: + return 6; + default: + return 0; + } + } +}; + +class Rounding { + +public: + static AVRounding intoAVRounding(int32_t RoundingId) { + switch (RoundingId) { + case 0: + return AV_ROUND_ZERO; + case 1: + return AV_ROUND_INF; + case 2: + return AV_ROUND_DOWN; + case 3: + return AV_ROUND_UP; + case 4: + return AV_ROUND_NEAR_INF; + case 5: + return AV_ROUND_PASS_MINMAX; + default: + return AV_ROUND_ZERO; + } + } + + static int32_t fromAVRounding(AVRounding Rounding) { + switch (Rounding) { + case AV_ROUND_ZERO: + return 0; + case AV_ROUND_INF: + return 1; + case AV_ROUND_DOWN: + return 2; + case AV_ROUND_UP: + return 3; + case AV_ROUND_NEAR_INF: + return 4; + case AV_ROUND_PASS_MINMAX: + return 5; + default: + return 0; + } + } +}; + +class OptionType { + +public: + static AVOptionType intoAVOptionType(int32_t RoundingId) { + switch (RoundingId) { + case 0: + return AV_OPT_TYPE_FLAGS; + case 1: + return AV_OPT_TYPE_INT; + case 2: + return AV_OPT_TYPE_INT64; + case 3: + return AV_OPT_TYPE_DOUBLE; + case 4: + return AV_OPT_TYPE_FLOAT; + case 5: + return AV_OPT_TYPE_STRING; + case 6: + return AV_OPT_TYPE_RATIONAL; + case 7: + return AV_OPT_TYPE_BINARY; + case 8: + return AV_OPT_TYPE_DICT; + case 9: + return AV_OPT_TYPE_CONST; + case 10: + return AV_OPT_TYPE_IMAGE_SIZE; + case 11: + return AV_OPT_TYPE_PIXEL_FMT; + case 12: + return AV_OPT_TYPE_SAMPLE_FMT; + case 13: + return AV_OPT_TYPE_VIDEO_RATE; + case 14: + return AV_OPT_TYPE_DURATION; + case 15: + return AV_OPT_TYPE_COLOR; + case 17: + return AV_OPT_TYPE_UINT64; + case 18: + return AV_OPT_TYPE_BOOL; + case 19: + return AV_OPT_TYPE_CHLAYOUT; + default: + return AV_OPT_TYPE_FLAGS; + } + } + + static int32_t fromAVOptionType(AVOptionType OptionType) { + switch (OptionType) { + case AV_OPT_TYPE_FLAGS: + return 0; + case AV_OPT_TYPE_INT: + return 1; + case AV_OPT_TYPE_INT64: + return 2; + case AV_OPT_TYPE_DOUBLE: + return 3; + case AV_OPT_TYPE_FLOAT: + return 4; + case AV_OPT_TYPE_STRING: + return 5; + case AV_OPT_TYPE_RATIONAL: + return 6; + case AV_OPT_TYPE_BINARY: + return 7; + case AV_OPT_TYPE_DICT: + return 8; + case AV_OPT_TYPE_CONST: + return 9; + case AV_OPT_TYPE_IMAGE_SIZE: + return 10; + case AV_OPT_TYPE_PIXEL_FMT: + return 11; + case AV_OPT_TYPE_SAMPLE_FMT: + return 12; + case AV_OPT_TYPE_VIDEO_RATE: + return 13; + case AV_OPT_TYPE_DURATION: + return 14; + case AV_OPT_TYPE_COLOR: + return 15; + case AV_OPT_TYPE_UINT64: + return 17; + case AV_OPT_TYPE_BOOL: + return 18; + case AV_OPT_TYPE_CHLAYOUT: + return 19; + default: + return 0; + } + } +}; + +class PictureType { +public: + static AVPictureType intoAVPictureType(int32_t PictureId) { + switch (PictureId) { + case 0: + return AV_PICTURE_TYPE_NONE; + case 1: + return AV_PICTURE_TYPE_I; + case 2: + return AV_PICTURE_TYPE_P; + case 3: + return AV_PICTURE_TYPE_B; + case 4: + return AV_PICTURE_TYPE_S; + case 5: + return AV_PICTURE_TYPE_SI; + case 6: + return AV_PICTURE_TYPE_SP; + case 7: + return AV_PICTURE_TYPE_BI; + default: + return AV_PICTURE_TYPE_NONE; + } + }; + + static int32_t fromAVPictureType(AVPictureType PictureType) { + switch (PictureType) { + case AV_PICTURE_TYPE_NONE: + return 0; + case AV_PICTURE_TYPE_I: + return 1; + case AV_PICTURE_TYPE_P: + return 2; + case AV_PICTURE_TYPE_B: + return 3; + case AV_PICTURE_TYPE_S: + return 4; + case AV_PICTURE_TYPE_SI: + return 5; + case AV_PICTURE_TYPE_SP: + return 6; + case AV_PICTURE_TYPE_BI: + return 7; + default: + return 0; + } + } +}; + +// Direct mapping in rust. Not required. Can be used for decoupling (Clean +// Code). +// +// class ColorTransferCharacteristic { +// +// static AVColorTransferCharacteristic +// intoColorTransferCharacteristic(uint32_t ColorTransferCharacteristicId) { +// switch (ColorTransferCharacteristicId) { +// case 0: +// return AVCOL_TRC_RESERVED0; +// case 1: +// return AVCOL_TRC_BT709; +// case 2: +// return AVCOL_TRC_UNSPECIFIED; +// case 3: +// return AVCOL_TRC_RESERVED; +// case 4: +// return AVCOL_TRC_GAMMA22; +// case 5: +// return AVCOL_TRC_GAMMA28; +// case 6: +// return AVCOL_TRC_SMPTE170M; +// case 7: +// return AVCOL_TRC_SMPTE240M; +// case 8: +// return AVCOL_TRC_LINEAR; +// case 9: +// return AVCOL_TRC_LOG; +// case 10: +// return AVCOL_TRC_LOG_SQRT; +// case 11: +// return AVCOL_TRC_IEC61966_2_4; +// case 12: +// return AVCOL_TRC_BT1361_ECG; +// case 13: +// return AVCOL_TRC_IEC61966_2_1; +// case 14: +// return AVCOL_TRC_BT2020_10; +// case 15: +// return AVCOL_TRC_BT2020_12; +// case 16: +// return AVCOL_TRC_SMPTE2084; +// case 17: +// return AVCOL_TRC_SMPTE428; +// case 18: +// return AVCOL_TRC_ARIB_STD_B67; +// case 19: +// return AVCOL_TRC_NB; +// default: +// return AVCOL_TRC_RESERVED0; +// } +// }; +// +// static uint32_t +// fromColorTransferCharacteristic(uint32_t ColorTransferCharacteristic) { +// switch (ColorTransferCharacteristic) { +// case AVCOL_TRC_RESERVED0: +// return 0; +// case AVCOL_TRC_BT709: +// return 1; +// case AVCOL_TRC_UNSPECIFIED: +// return 2; +// case AVCOL_TRC_RESERVED: +// return 3; +// case AVCOL_TRC_GAMMA22: +// return 4; +// case AVCOL_TRC_GAMMA28: +// return 5; +// case AVCOL_TRC_SMPTE170M: +// return 6; +// case AVCOL_TRC_SMPTE240M: +// return 7; +// case AVCOL_TRC_LINEAR: +// return 8; +// case AVCOL_TRC_LOG: +// return 9; +// case AVCOL_TRC_LOG_SQRT: +// return 10; +// case AVCOL_TRC_IEC61966_2_4: +// return 11; +// case AVCOL_TRC_BT1361_ECG: +// return 12; +// case AVCOL_TRC_IEC61966_2_1: +// return 13; +// case AVCOL_TRC_BT2020_10: +// return 14; +// case AVCOL_TRC_BT2020_12: +// return 15; +// case AVCOL_TRC_SMPTE2084: +// return 16; +// case AVCOL_TRC_SMPTE428: +// return 17; +// case AVCOL_TRC_ARIB_STD_B67: +// return 18; +// case AVCOL_TRC_NB: +// return 19; +// default: +// return 0; +// } +// }; +//}; + +// We can keep or remove the binding. +class ColorSpace { + +public: + static AVColorSpace intoAVColorSpace(int32_t ColorSpaceId) { + + switch (ColorSpaceId) { + case 0: + return AVCOL_SPC_RGB; + case 1: + return AVCOL_SPC_BT709; + case 2: + return AVCOL_SPC_UNSPECIFIED; + case 3: + return AVCOL_SPC_RESERVED; + case 4: + return AVCOL_SPC_FCC; + case 5: + return AVCOL_SPC_BT470BG; + case 6: + return AVCOL_SPC_SMPTE170M; + case 7: + return AVCOL_SPC_SMPTE240M; + case 8: + return AVCOL_SPC_YCGCO; + case 9: + return AVCOL_SPC_BT2020_NCL; + case 10: + return AVCOL_SPC_BT2020_CL; + case 11: + return AVCOL_SPC_SMPTE2085; + case 12: + return AVCOL_SPC_CHROMA_DERIVED_NCL; + case 13: + return AVCOL_SPC_CHROMA_DERIVED_CL; + case 14: + return AVCOL_SPC_ICTCP; + default: + return AVCOL_SPC_RGB; + } + }; + + static int32_t fromAVColorSpace(AVColorSpace ColorSpace) { + + switch (ColorSpace) { + case AVCOL_SPC_RGB: + return 0; + case AVCOL_SPC_BT709: + return 1; + case AVCOL_SPC_UNSPECIFIED: + return 2; + case AVCOL_SPC_RESERVED: + return 3; + case AVCOL_SPC_FCC: + return 4; + case AVCOL_SPC_BT470BG: + return 5; + case AVCOL_SPC_SMPTE170M: + return 6; + case AVCOL_SPC_SMPTE240M: + return 7; + case AVCOL_SPC_YCGCO: + return 8; + case AVCOL_SPC_BT2020_NCL: + return 9; + case AVCOL_SPC_BT2020_CL: + return 10; + case AVCOL_SPC_SMPTE2085: + return 11; + case AVCOL_SPC_CHROMA_DERIVED_NCL: + return 12; + case AVCOL_SPC_CHROMA_DERIVED_CL: + return 13; + case AVCOL_SPC_ICTCP: + return 14; + default: + return 0; + } + }; +}; + +class FieldOrder { +public: + static AVFieldOrder intoAVFieldOrder(int32_t FieldOrderId) { + switch (FieldOrderId) { + case 0: + return AV_FIELD_UNKNOWN; + case 1: + return AV_FIELD_PROGRESSIVE; + case 2: + return AV_FIELD_TT; + case 3: + return AV_FIELD_BB; + case 4: + return AV_FIELD_TB; + case 5: + return AV_FIELD_BT; + default: + return AV_FIELD_UNKNOWN; + } + } + + static int32_t fromAVFieldOrder(AVFieldOrder FieldOrder) { + switch (FieldOrder) { + case AV_FIELD_UNKNOWN: + return 0; + case AV_FIELD_PROGRESSIVE: + return 1; + case AV_FIELD_TT: + return 2; + case AV_FIELD_BB: + return 3; + case AV_FIELD_TB: + return 4; + case AV_FIELD_BT: + return 5; + default: + return 0; + } + } +}; + +class ColorPrimaries { + +public: + static AVColorPrimaries intoAVColorPrimaries(int32_t ColorPrimariesId) { + switch (ColorPrimariesId) { + case 0: + return AVCOL_PRI_RESERVED0; + case 1: + return AVCOL_PRI_BT709; + case 2: + return AVCOL_PRI_UNSPECIFIED; + case 3: + return AVCOL_PRI_RESERVED; + case 4: + return AVCOL_PRI_BT470M; + case 5: + return AVCOL_PRI_BT470BG; + case 6: + return AVCOL_PRI_SMPTE170M; + case 7: + return AVCOL_PRI_SMPTE240M; + case 8: + return AVCOL_PRI_FILM; + case 9: + return AVCOL_PRI_BT2020; + case 10: + return AVCOL_PRI_SMPTE428; + case 11: + return AVCOL_PRI_SMPTE431; + case 12: + return AVCOL_PRI_SMPTE432; + case 13: + return AVCOL_PRI_JEDEC_P22; + case 14: + return AVCOL_PRI_EBU3213; + default: + return AVCOL_PRI_RESERVED0; + } + }; + + static int32_t fromAVColorPrimaries(AVColorPrimaries ColorPrimaries) { + switch (ColorPrimaries) { + case AVCOL_PRI_RESERVED0: + return 0; + case AVCOL_PRI_BT709: + return 1; + case AVCOL_PRI_UNSPECIFIED: + return 2; + case AVCOL_PRI_RESERVED: + return 3; + case AVCOL_PRI_BT470M: + return 4; + case AVCOL_PRI_BT470BG: + return 5; + case AVCOL_PRI_SMPTE170M: + return 6; + case AVCOL_PRI_SMPTE240M: + return 7; + case AVCOL_PRI_FILM: + return 8; + case AVCOL_PRI_BT2020: + return 9; + case AVCOL_PRI_SMPTE428: + return 10; + case AVCOL_PRI_SMPTE431: + return 11; + case AVCOL_PRI_SMPTE432: + return 12; + // #[cfg(not(feature = "ffmpeg_4_3"))] + // case AVCOL_PRI_JEDEC_P22: + // return 13; + case AVCOL_PRI_EBU3213: + return 14; + default: + return 0; + } + }; +}; + +} // namespace FFmpegUtils +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_base.h b/plugins/wasmedge_ffmpeg/ffmpeg_base.h new file mode 100644 index 00000000..faac0e3d --- /dev/null +++ b/plugins/wasmedge_ffmpeg/ffmpeg_base.h @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_env.h" + +#include "runtime/callingframe.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +template class HostFunction : public Runtime::HostFunction { +public: + HostFunction(std::shared_ptr HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + std::shared_ptr Env; +}; + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp b/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp new file mode 100644 index 00000000..14ded24c --- /dev/null +++ b/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avcodec/module.h" +#include "avdevice/module.h" +#include "avfilter/module.h" +#include "avformat/module.h" +#include "avutil/module.h" +#include "swresample/module.h" +#include "swscale/module.h" + +#include "ffmpeg_env.h" + +namespace WasmEdge { +namespace Host { +namespace { + +Runtime::Instance::ModuleInstance * +createAVCodec(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::AVcodec::WasmEdgeFFmpegAVCodecModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createAVDevice(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::AVDevice::WasmEdgeFFmpegAVDeviceModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createAVFilter(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::AVFilter::WasmEdgeFFmpegAVFilterModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createAVFormat(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::AVFormat::WasmEdgeFFmpegAVFormatModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createAVUtil(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::AVUtil::WasmEdgeFFmpegAVUtilModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createSWResample(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::SWResample::WasmEdgeFFmpegSWResampleModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createSWScale(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::SWScale::WasmEdgeFFmpegSWScaleModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_ffmpeg", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 0, 0, 1}, + .ModuleCount = 7, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_ffmpeg_avcodec", + .Description = "encoding/decoding library", + .Create = createAVCodec, + }, + { + .Name = "wasmedge_ffmpeg_avdevice", + .Description = "special devices muxing/demuxing library ", + .Create = createAVDevice, + }, + { + .Name = "wasmedge_ffmpeg_avfilter", + .Description = "graph-based frame editing library", + .Create = createAVFilter, + }, + { + .Name = "wasmedge_ffmpeg_avformat", + .Description = "I/O and muxing/demuxing library", + .Create = createAVFormat, + }, + { + .Name = "wasmedge_ffmpeg_avutil", + .Description = "utils utility library", + .Create = createAVUtil, + }, + { + .Name = "wasmedge_ffmpeg_swresample", + .Description = "audio resampling, format conversion and mixing", + .Create = createSWResample, + }, + { + .Name = "wasmedge_ffmpeg_swscale", + .Description = "color conversion and scaling library", + .Create = createSWScale, + }}, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace + +std::weak_ptr + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::Instance = + std::make_shared(); + +std::shared_mutex WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::Mutex; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_env.h b/plugins/wasmedge_ffmpeg/ffmpeg_env.h new file mode 100644 index 00000000..c2ad46b7 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/ffmpeg_env.h @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "bindings.h" +#include "plugin/plugin.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +class WasmEdgeFFmpegEnv { +public: + // Singleton + static std::shared_ptr getInstance() noexcept { + std::unique_lock Lock(Mutex); + std::shared_ptr EnvPtr = Instance.lock(); + if (!EnvPtr) { + EnvPtr.reset(new WasmEdgeFFmpegEnv()); + Instance = EnvPtr; + } + return EnvPtr; + } + + // Avoid copy constructor and overloading functions. + WasmEdgeFFmpegEnv(const WasmEdgeFFmpegEnv &) = delete; + void operator=(const WasmEdgeFFmpegEnv &) = delete; + + void alloc(void *Data, uint32_t *DataPtr) { + FfmpegPtrMap[FfmpegPtrAllocateKey++] = Data; + *DataPtr = FfmpegPtrAllocateKey - 1; + } + + void *fetchData(const size_t Index) { + if (Index >= FfmpegPtrAllocateKey) { + return nullptr; + } + // Check this condition. + if (FfmpegPtrMap[Index] == nullptr) { + return nullptr; + } + + return FfmpegPtrMap[Index]; + } + + void dealloc(size_t Index) { + + if (Index >= FfmpegPtrAllocateKey) { + return; + } + + FfmpegPtrMap.erase(Index); + } + + WasmEdgeFFmpegEnv() noexcept {} + +private: + // Using zero as NULL Value. + uint32_t FfmpegPtrAllocateKey = 1; + // Can update this to uint64_t to get more memory. + std::map FfmpegPtrMap; + static std::weak_ptr Instance; + static std::shared_mutex Mutex; +}; + +// Utils functions. +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-FFmpeg] Memory instance not found."sv); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define FFMPEG_PTR_FETCH(StructPtr, FFmpegStructId, Type) \ + Type *StructPtr = nullptr; \ + if (FFmpegStructId != 0) \ + StructPtr = static_cast(Env.get()->fetchData(FFmpegStructId)); + +#define MEM_SPAN_CHECK(OutSpan, MemInst, Type, BufPtr, BufLen, Message) \ + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); \ + if (unlikely(OutSpan.size() != BufLen)) { \ + spdlog::error("[WasmEdge-FFmpeg] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define FFMPEG_PTR_STORE(StructPtr, FFmpegStructId) \ + Env.get()->alloc(StructPtr, FFmpegStructId); + +#define FFMPEG_PTR_DELETE(FFmpegStructId) Env.get()->dealloc(FFmpegStructId); + +#define MEM_PTR_CHECK(OutPtr, MemInst, Type, Offset, Message) \ + Type *OutPtr = MemInst->getPointerOrNull(Offset); \ + if (unlikely(OutPtr == nullptr)) { \ + spdlog::error("[WasmEdge-FFmpeg] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +// Starting from 200 because, posix codes take values till 131. +// Hence using 200. +enum class ErrNo : int32_t { + Success = 0, // No error occurred. + MissingMemory = -201, // Caller module is missing a memory export. + NullStructId = -202, // Rust Sdk Passes null id. + InternalError = -203, + UnImplemented = -204 // Unimplemented funcs. +}; + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swresample/module.cpp b/plugins/wasmedge_ffmpeg/swresample/module.cpp new file mode 100644 index 00000000..5b5f9867 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swresample/module.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "swresample_func.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWResample { + +WasmEdgeFFmpegSWResampleModule::WasmEdgeFFmpegSWResampleModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_swresample") { + addHostFunc("wasmedge_ffmpeg_swresample_swresample_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swr_get_delay", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swr_init", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swr_alloc_set_opts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_av_opt_set_dict", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swr_convert_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swr_free", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swresample_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swresample_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swresample_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swresample_license", + std::make_unique(Env)); +} + +} // namespace SWResample +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swresample/module.h b/plugins/wasmedge_ffmpeg/swresample/module.h new file mode 100644 index 00000000..00d4c839 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swresample/module.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWResample { + +class WasmEdgeFFmpegSWResampleModule + : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegSWResampleModule(std::shared_ptr Env); +}; + +} // namespace SWResample +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp b/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp new file mode 100644 index 00000000..07244499 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "swresample_func.h" + +extern "C" { +#include "libavutil/avutil.h" +#include "libavutil/opt.h" +#include "libswresample/swresample.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWResample { + +Expect SWResampleVersion::body(const Runtime::CallingFrame &) { + return swresample_version(); +} + +Expect SWRGetDelay::body(const Runtime::CallingFrame &, + uint32_t SWRContextId, int64_t Base) { + FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); + return swr_get_delay(SWRContext, Base); +} + +Expect SWRInit::body(const Runtime::CallingFrame &, + uint32_t SWRContextId) { + FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); + return swr_init(SWRContext); +} + +Expect +SWRAllocSetOpts::body(const Runtime::CallingFrame &Frame, uint32_t SwrCtxPtr, + uint32_t SWRContextId, uint64_t OutChLayoutId, + uint32_t OutSampleFmtId, int32_t OutSampleRate, + uint64_t InChLayoutId, uint32_t InSampleFmtId, + int32_t InSampleRate, int32_t LogOffset) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwrCtxId, MemInst, uint32_t, SwrCtxPtr, "") + FFMPEG_PTR_FETCH(CurrSwrCtx, *SwrCtxId, SwrContext); + FFMPEG_PTR_FETCH(ExistSWRContext, SWRContextId, SwrContext); + + uint64_t const OutChLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(OutChLayoutId); + AVSampleFormat const OutSampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(OutSampleFmtId); + uint64_t const InChLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(InChLayoutId); + AVSampleFormat const InSampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(InSampleFmtId); + + AVChannelLayout AVOutChLayout; + av_channel_layout_from_mask(&AVOutChLayout, OutChLayout); + + AVChannelLayout AVInChLayout; + av_channel_layout_from_mask(&AVInChLayout, InChLayout); + + swr_alloc_set_opts2(&ExistSWRContext, &AVOutChLayout, OutSampleFmt, + OutSampleRate, &AVInChLayout, InSampleFmt, InSampleRate, + LogOffset, + nullptr); // Always being used as null in rust sdk. + CurrSwrCtx = ExistSWRContext; + + av_channel_layout_uninit(&AVOutChLayout); + av_channel_layout_uninit(&AVInChLayout); + + FFMPEG_PTR_STORE(CurrSwrCtx, SwrCtxId); + return static_cast(ErrNo::Success); +} + +Expect AVOptSetDict::body(const Runtime::CallingFrame &, + uint32_t SWRContextId, uint32_t DictId) { + FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); + FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); + return av_opt_set_dict(SWRContext, AvDictionary); +} + +Expect SWRConvertFrame::body(const Runtime::CallingFrame &, + uint32_t SWRContextId, + uint32_t FrameOutputId, + uint32_t FrameInputId) { + FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); + FFMPEG_PTR_FETCH(OuputFrame, FrameOutputId, AVFrame); + FFMPEG_PTR_FETCH(InputFrame, FrameInputId, AVFrame); + + return swr_convert_frame(SWRContext, OuputFrame, InputFrame); +} + +Expect SWRFree::body(const Runtime::CallingFrame &, + uint32_t SWRContextId) { + FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); + swr_close(SWRContext); + FFMPEG_PTR_DELETE(SWRContextId); + return static_cast(ErrNo::Success); +} + +Expect +SWResampleConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = swresample_configuration(); + return strlen(Config); +} + +Expect +SWResampleConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, uint32_t ConfigLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = swresample_configuration(); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect SWResampleLicenseLength::body(const Runtime::CallingFrame &) { + const char *License = swresample_license(); + return strlen(License); +} + +Expect SWResampleLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, + uint32_t LicenseLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = swresample_license(); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace SWResample +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_func.h b/plugins/wasmedge_ffmpeg/swresample/swresample_func.h new file mode 100644 index 00000000..61104d70 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_func.h @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWResample { + +class SWResampleVersion : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SWRGetDelay : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SWRContextId, int64_t Base); +}; + +class SWRInit : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SWRContextId); +}; + +class SWRAllocSetOpts : public HostFunction { +public: + using HostFunction::HostFunction; + + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwrCtxPtr, + uint32_t SWRContextId, uint64_t OutChLayout, + uint32_t OutSampleFmtId, int32_t OutSampleRate, + uint64_t InChLayout, uint32_t InSampleFmtId, + int32_t InSampleRate, int32_t LogOffset); +}; + +class AVOptSetDict : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SWRContextId, uint32_t DictId); +}; + +class SWRConvertFrame : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SWRContextId, uint32_t FrameOutputId, + uint32_t FrameInputId); +}; + +class SWRFree : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SWRContextId); +}; + +class SWResampleConfigurationLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SWResampleConfiguration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class SWResampleLicenseLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SWResampleLicense : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace SWResample +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swscale/module.cpp b/plugins/wasmedge_ffmpeg/swscale/module.cpp new file mode 100644 index 00000000..da84ee06 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swscale/module.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "swscale_func.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWScale { + +WasmEdgeFFmpegSWScaleModule::WasmEdgeFFmpegSWScaleModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_swscale") { + addHostFunc("wasmedge_ffmpeg_swscale_swscale_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_swscale_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_swscale_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_swscale_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_swscale_license", + std::make_unique(Env)); + + // SwsContext + addHostFunc("wasmedge_ffmpeg_swscale_sws_getContext", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_freeContext", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_scale", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getCachedContext", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_isSupportedInput", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_isSupportedOutput", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_isSupportedEndiannessConversion", + std::make_unique(Env)); + + // SwsFilter + addHostFunc("wasmedge_ffmpeg_swscale_sws_getDefaultFilter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getLumaH", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getLumaV", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getChromaH", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getChromaV", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_freeFilter", + std::make_unique(Env)); + + // SwsVector + addHostFunc("wasmedge_ffmpeg_swscale_sws_allocVec", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getGaussianVec", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_scaleVec", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_normalizeVec", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getCoeffVecLength", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getCoeff", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_freeVec", + std::make_unique(Env)); +} + +} // namespace SWScale +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swscale/module.h b/plugins/wasmedge_ffmpeg/swscale/module.h new file mode 100644 index 00000000..e9ca104d --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swscale/module.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWScale { + +class WasmEdgeFFmpegSWScaleModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegSWScaleModule(std::shared_ptr Env); +}; + +} // namespace SWScale +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp b/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp new file mode 100644 index 00000000..90107b8b --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "swscale_func.h" + +extern "C" { +#include "libavutil/frame.h" +#include "libswscale/swscale.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWScale { + +Expect +SwsGetContext::body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxPtr, + uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, + uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, + int32_t Flags, uint32_t SrcFilterId, uint32_t DesFilterId) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsCtxId, MemInst, uint32_t, SwsCtxPtr, + "Failed when accessing the return SWSContext Memory"sv) + + FFMPEG_PTR_FETCH(SwsCtx, *SwsCtxId, SwsContext) + FFMPEG_PTR_FETCH(SrcSwsFilter, SrcFilterId, SwsFilter) + FFMPEG_PTR_FETCH(DesSwsFilter, DesFilterId, SwsFilter) + + AVPixelFormat const SrcPixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(SrcPixFormatId); + AVPixelFormat const DestPixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(DesPixFormatId); + SwsCtx = sws_getContext(SrcW, SrcH, SrcPixelFormat, DesW, DesH, + DestPixelFormat, Flags, SrcSwsFilter, DesSwsFilter, + nullptr); // Not using param anywhere in Rust SDK. + if (SwsCtx == nullptr) { + return static_cast(ErrNo::InternalError); + } + FFMPEG_PTR_STORE(SwsCtx, SwsCtxId); + return static_cast(ErrNo::Success); +} + +Expect SwsFreeContext::body(const Runtime::CallingFrame &, + uint32_t SwsCtxId) { + FFMPEG_PTR_FETCH(SwsCtx, SwsCtxId, SwsContext) + sws_freeContext(SwsCtx); + FFMPEG_PTR_DELETE(SwsCtxId); + return static_cast(ErrNo::Success); +} + +Expect SwsScale::body(const Runtime::CallingFrame &, uint32_t SwsCtxId, + uint32_t InputFrameId, int32_t SrcSliceY, + int32_t SrcSliceH, uint32_t OutputFrameId) { + FFMPEG_PTR_FETCH(SwsCtx, SwsCtxId, SwsContext); + FFMPEG_PTR_FETCH(InputFrame, InputFrameId, AVFrame); + FFMPEG_PTR_FETCH(OutputFrame, OutputFrameId, AVFrame); + return sws_scale(SwsCtx, InputFrame->data, InputFrame->linesize, SrcSliceY, + SrcSliceH, OutputFrame->data, OutputFrame->linesize); +} + +Expect SwsGetCachedContext::body( + const Runtime::CallingFrame &Frame, uint32_t SwsCachedCtxPtr, + uint32_t SwsCtxId, uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, + uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, int32_t Flags, + uint32_t SrcFilterId, uint32_t DesFilterId) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsCachedCtxId, MemInst, uint32_t, SwsCachedCtxPtr, "") + + FFMPEG_PTR_FETCH(SwsCachedCtx, *SwsCachedCtxId, SwsContext); + FFMPEG_PTR_FETCH(SwsCtx, SwsCtxId, SwsContext); + FFMPEG_PTR_FETCH(SrcSwsFilter, SrcFilterId, SwsFilter) + FFMPEG_PTR_FETCH(DesSwsFilter, DesFilterId, SwsFilter) + + AVPixelFormat const SrcPixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(SrcPixFormatId); + AVPixelFormat const DestPixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(DesPixFormatId); + SwsCachedCtx = sws_getCachedContext(SwsCtx, SrcW, SrcH, SrcPixelFormat, DesW, + DesH, DestPixelFormat, Flags, + SrcSwsFilter, DesSwsFilter, nullptr); + if (SwsCachedCtx == nullptr) { + return static_cast(ErrNo::InternalError); + } + + FFMPEG_PTR_STORE(SwsCachedCtx, SwsCachedCtxId); + return static_cast(ErrNo::Success); +} + +Expect SwsIsSupportedInput::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + return sws_isSupportedInput(PixelFormat); +} + +Expect SwsIsSupportedOutput::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + return sws_isSupportedOutput(PixelFormat); +} + +Expect +SwsIsSupportedEndiannessConversion::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + return sws_isSupportedEndiannessConversion(PixelFormat); +} + +Expect SwsGetDefaultFilter::body( + const Runtime::CallingFrame &Frame, uint32_t SwsFilterPtr, float LumaGBlur, + float ChromaGBlur, float LumaSharpen, float ChromaSharpen, + float ChromaHShift, float ChromaVShift, int32_t Verbose) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsFilterId, MemInst, uint32_t, SwsFilterPtr, "") + + SwsFilter *Filter = + sws_getDefaultFilter(LumaGBlur, ChromaGBlur, LumaSharpen, ChromaSharpen, + ChromaHShift, ChromaVShift, Verbose); + if (Filter == nullptr) { + return static_cast(ErrNo::InternalError); + } + FFMPEG_PTR_STORE(Filter, SwsFilterId); + return static_cast(ErrNo::Success); +} + +Expect SwsGetLumaH::body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterId, uint32_t SwsVectorPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); + + SwsVector *Vector = Filter->lumH; + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsGetLumaV::body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterId, uint32_t SwsVectorPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); + + SwsVector *Vector = Filter->lumV; + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsGetChromaH::body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterId, + uint32_t SwsVectorPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); + + SwsVector *Vector = Filter->chrH; + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsGetChromaV::body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterId, + uint32_t SwsVectorPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); + + SwsVector *Vector = Filter->chrV; + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsFreeFilter::body(const Runtime::CallingFrame &, + uint32_t SwsFilterId) { + FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); + sws_freeFilter(Filter); + FFMPEG_PTR_DELETE(SwsFilterId); + return static_cast(ErrNo::Success); +} + +Expect SwsAllocVec::body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorPtr, int32_t Length) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + + SwsVector *Vector = sws_allocVec(Length); + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsGetGaussianVec::body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorPtr, double Variance, + double Quality) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + + SwsVector *Vector = sws_getGaussianVec(Variance, Quality); + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsScaleVec::body(const Runtime::CallingFrame &, + uint32_t SwsVectorId, double Scalar) { + FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); + sws_scaleVec(Vector, Scalar); + return static_cast(ErrNo::Success); +} + +Expect SwsNormalizeVec::body(const Runtime::CallingFrame &, + uint32_t SwsVectorId, double Height) { + FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); + sws_normalizeVec(Vector, Height); + return static_cast(ErrNo::Success); +} + +Expect SwsGetCoeffVecLength::body(const Runtime::CallingFrame &, + uint32_t SwsVectorId) { + FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); + return Vector->length * + sizeof(double); // Getting the size in uint_8* (Cuz Passing uint8_t* + // array from Rust SDK). +} + +Expect SwsGetCoeff::body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorId, uint32_t CoeffBufPtr, + uint32_t Len) { + MEMINST_CHECK(MemInst, Frame, 0) + MEM_SPAN_CHECK(Buffer, MemInst, uint8_t, CoeffBufPtr, Len, ""); + FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); + + double *Coeff = Vector->coeff; + std::copy_n(Coeff, Len, Buffer.data()); + return static_cast(ErrNo::Success); +} + +Expect SwsFreeVec::body(const Runtime::CallingFrame &, + uint32_t SwsVectorId) { + FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); + sws_freeVec(Vector); + FFMPEG_PTR_DELETE(SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwscaleVersion::body(const Runtime::CallingFrame &) { + return swscale_version(); +} + +Expect +SwscaleConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = swscale_configuration(); + return strlen(Config); +} + +Expect SwscaleConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = swscale_configuration(); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect SwscaleLicenseLength::body(const Runtime::CallingFrame &) { + const char *License = swscale_license(); + return strlen(License); +} + +Expect SwscaleLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, uint32_t LicenseLen) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = swscale_license(); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace SWScale +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_func.h b/plugins/wasmedge_ffmpeg/swscale/swscale_func.h new file mode 100644 index 00000000..12edd643 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swscale/swscale_func.h @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ffmpeg_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWScale { + +class SwsGetContext : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxPtr, + uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, + uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, + int32_t Flags, uint32_t SrcFilterId, + uint32_t DesFilterId); +}; + +class SwsFreeContext : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxId); +}; + +class SwsScale : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxId, + uint32_t InputFrameId, int32_t SrcSliceY, + int32_t SrcSliceH, uint32_t OutputFrameId); +}; + +class SwsGetCachedContext : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsCachedCtxPtr, uint32_t SwsCtxPtr, + uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, + uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, + int32_t Flags, uint32_t SrcFilterId, + uint32_t DesFilterId); +}; + +class SwsIsSupportedInput : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class SwsIsSupportedOutput : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class SwsIsSupportedEndiannessConversion + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class SwsGetDefaultFilter : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterPtr, float LumaGBlur, + float ChromaGBlur, float LumaSharpen, + float ChromaSharpen, float ChromaHShift, + float ChromaVShift, int32_t Verbose); +}; + +class SwsGetLumaH : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, + uint32_t SwsVectorPtr); +}; + +class SwsGetLumaV : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, + uint32_t SwsVectorPtr); +}; + +class SwsGetChromaH : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, + uint32_t SwsVectorPtr); +}; + +class SwsGetChromaV : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, + uint32_t SwsVectorPtr); +}; + +class SwsFreeFilter : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterId); +}; + +class SwsAllocVec : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorPtr, int32_t Length); +}; + +class SwsGetGaussianVec : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorPtr, double Variance, double Quality); +}; + +class SwsScaleVec : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorId, + double Scalar); +}; + +class SwsNormalizeVec : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorId, + double Height); +}; + +class SwsGetCoeffVecLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t SwsVectorId); +}; + +class SwsGetCoeff : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &, uint32_t SwsVectorId, + uint32_t CoeffBuf, uint32_t Len); +}; + +class SwsFreeVec : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorId); +}; + +class SwscaleVersion : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SwscaleConfigurationLength + : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SwscaleConfiguration : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class SwscaleLicenseLength : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SwscaleLicense : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace SWScale +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt new file mode 100644 index 00000000..fd24e7eb --- /dev/null +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_library(wasmedgePluginWasmEdgeImage + SHARED + image_env.cpp + image_func.cpp + image_module.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeImage + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeImage + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +# Need stb_image. +wasmedge_setup_stb_image() +target_link_libraries(wasmedgePluginWasmEdgeImage + PUBLIC + wasmedgeDepsSTBImage +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeImage + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeImage + PRIVATE + wasmedge_shared + ) +endif() + +install( + TARGETS wasmedgePluginWasmEdgeImage + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_image/image_base.h b/plugins/wasmedge_image/image_base.h new file mode 100644 index 00000000..47f501b2 --- /dev/null +++ b/plugins/wasmedge_image/image_base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "image_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeImage { + +template class Func : public Runtime::HostFunction { +public: + Func(ImgEnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + ImgEnv &Env; +}; + +} // namespace WasmEdgeImage +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_env.cpp b/plugins/wasmedge_image/image_env.cpp new file mode 100644 index 00000000..12ab8886 --- /dev/null +++ b/plugins/wasmedge_image/image_env.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "image_env.h" +#include "image_module.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeImageModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_image", + .Description = "Image loading plug-in for WasmEdge.", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 13, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_image", + .Description = + "This module contains WasmEdge-Image host functions.", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_env.h b/plugins/wasmedge_image/image_env.h new file mode 100644 index 00000000..837b7086 --- /dev/null +++ b/plugins/wasmedge_image/image_env.h @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeImage { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + Fail = 1, // Runtime Error. +}; + +enum class DataType : uint32_t { + RGB8 = 0, + BGR8 = 1, + RGB32F = 2, + BGR32F = 3, +}; + +struct ImgEnv {}; + +} // namespace WasmEdgeImage +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_func.cpp b/plugins/wasmedge_image/image_func.cpp new file mode 100644 index 00000000..216e8e82 --- /dev/null +++ b/plugins/wasmedge_image/image_func.cpp @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "image_func.h" + +#include "common/span.h" +#include "common/spdlog.h" + +#define STB_IMAGE_IMPLEMENTATION +#include +#define STB_IMAGE_RESIZE_IMPLEMENTATION +#include + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeImage { + +namespace { + +bool decodeImgToSize(Span Buf, uint32_t W, uint32_t H, + DataType OutType, Span DstBuf) noexcept { + // Specify the target data format. + bool IsRGB = true; + bool IsU8 = true; + switch (OutType) { + case DataType::BGR8: + IsRGB = false; + [[fallthrough]]; + case DataType::RGB8: + break; + case DataType::BGR32F: + IsRGB = false; + [[fallthrough]]; + case DataType::RGB32F: + IsU8 = false; + break; + default: + return false; + } + + // Load and decode the image from buffer. + union RawImagePtr { + uint8_t *U8; + float *F32; + }; + RawImagePtr RawImg; + RawImg.U8 = nullptr; + int IW, IH, IC; + if (IsU8) { + RawImg.U8 = stbi_load_from_memory(Buf.data(), Buf.size(), &IW, &IH, &IC, 3); + } else { + RawImg.F32 = + stbi_loadf_from_memory(Buf.data(), Buf.size(), &IW, &IH, &IC, 3); + } + if (RawImg.U8 == nullptr) { + spdlog::error("[WasmEdge-Image] Load image failed."sv); + return false; + } + + // Resize. + if (unlikely(DstBuf.size() < + W * H * 3 * (IsU8 ? sizeof(uint8_t) : sizeof(float)))) { + spdlog::error("[WasmEdge-Image] Output buffer size {} not enough. "sv + "At least need {} bytes."sv, + DstBuf.size(), + W * H * 3 * (IsU8 ? sizeof(uint8_t) : sizeof(float))); + return false; + } + if (IsU8) { + stbir_resize_uint8_linear(RawImg.U8, IW, IH, 0, DstBuf.data(), + static_cast(W), static_cast(H), 0, + STBIR_RGB); + } else { + stbir_resize_float_linear( + RawImg.F32, IW, IH, 0, reinterpret_cast(DstBuf.data()), + static_cast(W), static_cast(H), 0, STBIR_RGB); + } + + // Handle BGR case. + if (!IsRGB) { + if (IsU8) { + for (uint32_t I = 0; I < W * H; I++) { + std::swap(DstBuf[I * 3], DstBuf[I * 3 + 2]); + } + } else { + auto F32DstBuf = Span(reinterpret_cast(DstBuf.data()), + DstBuf.size() / sizeof(float)); + for (uint32_t I = 0; I < W * H; I++) { + std::swap(F32DstBuf[I * 3], F32DstBuf[I * 3 + 2]); + } + } + } + stbi_image_free(RawImg.U8); + return true; +} + +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Image] Memory instance not found."sv); \ + return static_cast(ErrNo::Fail); \ + } + +#define MEM_SPAN_CHECK(OutSpan, MemInst, Type, BufPtr, BufLen, Message) \ + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); \ + if (unlikely(OutSpan.size() != BufLen)) { \ + spdlog::error("[WasmEdge-Image] "sv Message); \ + return static_cast(ErrNo::Fail); \ + } + +} // namespace + +Expect LoadJPG::body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, + uint32_t OutType, uint32_t OutBufPtr, + uint32_t OutBufLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input image buffer. + MEM_SPAN_CHECK(ImgBufSpan, MemInst, uint8_t, InImgBufPtr, InImgBufLen, + "Failed when accessing the input image buffer memory."sv) + + // Check the output decoded image buffer. + MEM_SPAN_CHECK(OutBufSpan, MemInst, uint8_t, OutBufPtr, OutBufLen, + "Failed when accessing the output image data buffer memory."sv) + + if (unlikely(!decodeImgToSize(ImgBufSpan, OutImgW, OutImgH, + static_cast(OutType), OutBufSpan))) { + return static_cast(ErrNo::Fail); + } + return static_cast(ErrNo::Success); +} + +Expect LoadPNG::body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, + uint32_t OutType, uint32_t OutBufPtr, + uint32_t OutBufLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input image buffer. + MEM_SPAN_CHECK(ImgBufSpan, MemInst, uint8_t, InImgBufPtr, InImgBufLen, + "Failed when accessing the input image buffer memory."sv) + + // Check the output decoded image buffer. + MEM_SPAN_CHECK(OutBufSpan, MemInst, uint8_t, OutBufPtr, OutBufLen, + "Failed when accessing the output image data buffer memory."sv) + + if (unlikely(!decodeImgToSize(ImgBufSpan, OutImgW, OutImgH, + static_cast(OutType), OutBufSpan))) { + return static_cast(ErrNo::Fail); + } + return static_cast(ErrNo::Success); +} + +Expect LoadImage::body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, + uint32_t OutType, uint32_t OutBufPtr, + uint32_t OutBufLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input image buffer. + MEM_SPAN_CHECK(ImgBufSpan, MemInst, uint8_t, InImgBufPtr, InImgBufLen, + "Failed when accessing the input image buffer memory."sv) + + // Check the output decoded image buffer. + MEM_SPAN_CHECK(OutBufSpan, MemInst, uint8_t, OutBufPtr, OutBufLen, + "Failed when accessing the output image data buffer memory."sv) + + if (unlikely(!decodeImgToSize(ImgBufSpan, OutImgW, OutImgH, + static_cast(OutType), OutBufSpan))) { + return static_cast(ErrNo::Fail); + } + return static_cast(ErrNo::Success); +} + +} // namespace WasmEdgeImage +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_func.h b/plugins/wasmedge_image/image_func.h new file mode 100644 index 00000000..4b6bc14f --- /dev/null +++ b/plugins/wasmedge_image/image_func.h @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "image_base.h" + +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeImage { + +class LoadJPG : public Func { +public: + LoadJPG(ImgEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, uint32_t OutType, + uint32_t OutBufPtr, uint32_t OutBufLen); +}; + +class LoadPNG : public Func { +public: + LoadPNG(ImgEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, uint32_t OutType, + uint32_t OutBufPtr, uint32_t OutBufLen); +}; + +class LoadImage : public Func { +public: + LoadImage(ImgEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, uint32_t OutType, + uint32_t OutBufPtr, uint32_t OutBufLen); +}; + +} // namespace WasmEdgeImage +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_module.cpp b/plugins/wasmedge_image/image_module.cpp new file mode 100644 index 00000000..65e71479 --- /dev/null +++ b/plugins/wasmedge_image/image_module.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "image_module.h" +#include "image_func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasmEdgeImageModule::WasmEdgeImageModule() + : Runtime::Instance::ModuleInstance("wasmedge_image") { + addHostFunc("load_jpg", std::make_unique(Env)); + addHostFunc("load_png", std::make_unique(Env)); + addHostFunc("load_image", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_module.h b/plugins/wasmedge_image/image_module.h new file mode 100644 index 00000000..8d4f42b4 --- /dev/null +++ b/plugins/wasmedge_image/image_module.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "image_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeImageModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeImageModule(); + ~WasmEdgeImageModule() = default; + + WasmEdgeImage::ImgEnv &getEnv() { return Env; } + +private: + WasmEdgeImage::ImgEnv Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_llmc/CMakeLists.txt b/plugins/wasmedge_llmc/CMakeLists.txt new file mode 100644 index 00000000..da37ed2f --- /dev/null +++ b/plugins/wasmedge_llmc/CMakeLists.txt @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_library(wasmedgePluginWasmEdgeLLMC + SHARED + llmc_func.cpp + llmc_module.cpp + llmc_env.cpp +) + +option(WASMEDGE_PLUGIN_LLMC_CUDA "Training GPT2 with CUDA" OFF) + +message(STATUS "Start fetching llm.c source") +include(FetchContent) + +if (WASMEDGE_PLUGIN_LLMC_CUDA) + set(CUDALIB ON) + message(STATUS "Build wasmedge_llmc with CUDA backend") +else() + message(STATUS "Build wasmedge_llmc with CPU backend") +endif() + +FetchContent_Declare( + llmc + GIT_REPOSITORY https://github.com/WasmEdge/llm.c +) +FetchContent_MakeAvailable(llmc) + +if (WASMEDGE_PLUGIN_LLMC_CUDA) + target_link_libraries(wasmedgePluginWasmEdgeLLMC PRIVATE + train_gpt2_cuda + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeLLMC PRIVATE + train_gpt2_cpu + ) +endif() + +target_compile_options(wasmedgePluginWasmEdgeLLMC + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeLLMC + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeLLMC + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeLLMC + PRIVATE + wasmedge_shared + ) +endif() + +install( + TARGETS wasmedgePluginWasmEdgeLLMC + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_llmc/llmc_base.h b/plugins/wasmedge_llmc/llmc_base.h new file mode 100644 index 00000000..6b35ffff --- /dev/null +++ b/plugins/wasmedge_llmc/llmc_base.h @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "llmc_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeLLMC { + +template class HostFunction : public Runtime::HostFunction { +public: + HostFunction(LLMCEnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + static constexpr uint32_t castErrNo(ErrNo E) noexcept { + return static_cast(E); + } + LLMCEnv &Env; +}; + +} // namespace WasmEdgeLLMC +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_llmc/llmc_env.cpp b/plugins/wasmedge_llmc/llmc_env.cpp new file mode 100644 index 00000000..a960ec16 --- /dev/null +++ b/plugins/wasmedge_llmc/llmc_env.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "llmc_env.h" +#include "llmc_fwd.h" +#include "llmc_module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeLLMC { + +uint32_t LLMCEnv::addModel(GPT2 *M) noexcept { + Models.push_back(M); + return Models.size() - 1; +} + +GPT2 *LLMCEnv::getModel(uint32_t Id) noexcept { + assert(Id < Models.size() && "Out of bounds"); + return Models[Id]; +} + +uint32_t LLMCEnv::addTokenizer(Tokenizer *T) noexcept { + Tokenizers.push_back(T); + return Tokenizers.size() - 1; +} + +Tokenizer *LLMCEnv::getTokenizer(uint32_t Id) noexcept { + assert(Id < Tokenizers.size() && "Out of bounds"); + return Tokenizers[Id]; +} + +uint32_t LLMCEnv::addDataLoader(DataLoader *D) noexcept { + DataLoaders.push_back(D); + return DataLoaders.size() - 1; +} + +DataLoader *LLMCEnv::getDataLoader(uint32_t Id) noexcept { + assert(Id < DataLoaders.size() && "Out of bounds"); + return DataLoaders[Id]; +} + +LLMCEnv::~LLMCEnv() { + for (GPT2 *M : Models) { + gpt2_destroy(M); + } + for (DataLoader *DL : DataLoaders) { + dataloader_destroy(DL); + } + for (Tokenizer *T : Tokenizers) { + tokenizer_destroy(T); + } +} + +namespace { +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeLLMCModule; +} + +static Plugin::PluginModule::ModuleDescriptor MD[] = { + { + /* Name */ "wasmedge_llmc", + /* Description */ "", + /* Create */ create, + }, +}; + +Plugin::Plugin::PluginDescriptor Descriptor{ + /* Name */ "wasmedge_llmc", + /* Description */ "", + /* APIVersion */ Plugin::Plugin::CurrentAPIVersion, + /* Version */ {0, 1, 0, 0}, + /* ModuleCount */ 1, + /* ModuleDescriptions */ MD, + /* ComponentCount */ 0, + /* ComponentDescriptions */ nullptr, + /*AddOptions*/ nullptr, +}; +} // namespace + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace WasmEdgeLLMC +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_llmc/llmc_env.h b/plugins/wasmedge_llmc/llmc_env.h new file mode 100644 index 00000000..66b9b8c1 --- /dev/null +++ b/plugins/wasmedge_llmc/llmc_env.h @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include +#include +#include + +extern "C" { +struct GPT2; +struct Tokenizer; +struct DataLoader; +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeLLMC { + +enum class ErrNo : uint32_t { + Success = 0, + InvalidArgument = 1, + MissingMemory = 2, +}; + +class LLMCEnv { + std::vector Models; + std::vector Tokenizers; + std::vector DataLoaders; + +public: + uint32_t addModel(GPT2 *M) noexcept; + + GPT2 *getModel(uint32_t Id) noexcept; + + size_t getModelSize() const noexcept { return Models.size(); } + + uint32_t addTokenizer(Tokenizer *T) noexcept; + + Tokenizer *getTokenizer(uint32_t Id) noexcept; + + size_t getTokenizerSize() const noexcept { return Tokenizers.size(); } + + uint32_t addDataLoader(DataLoader *D) noexcept; + + DataLoader *getDataLoader(uint32_t Id) noexcept; + + size_t getDataLoaderSize() const noexcept { return DataLoaders.size(); } + + ~LLMCEnv(); +}; + +} // namespace WasmEdgeLLMC +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_llmc/llmc_func.cpp b/plugins/wasmedge_llmc/llmc_func.cpp new file mode 100644 index 00000000..b90fba78 --- /dev/null +++ b/plugins/wasmedge_llmc/llmc_func.cpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "llmc_func.h" +#include "llmc_fwd.h" + +#include "common/errcode.h" +#include "common/spdlog.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeLLMC { + +Expect ModelCreate::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t CheckPointPath, + uint32_t CheckPointPathLen, + uint32_t ModelIdPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-LLMC] Memory instance not found."sv); + return ErrNo::MissingMemory; + } + auto CheckPointPathSpan = + MemInst->getSpan(CheckPointPath, CheckPointPathLen); + if (unlikely(CheckPointPathSpan.size() != CheckPointPathLen)) { + spdlog::error( + "[WasmEdge-LLMC] Failed when accessing the input checkpoint path memory."sv); + return ErrNo::MissingMemory; + } + + auto *ModelId = MemInst->getPointer(ModelIdPtr); + if (unlikely(ModelId == nullptr)) { + spdlog::error( + "[WasmEdge-LLMC] Failed when accessing the return model memory."sv); + return ErrNo::InvalidArgument; + } + std::string CheckPointPathStr = + std::string(CheckPointPathSpan.begin(), + CheckPointPathSpan.begin() + CheckPointPathSpan.size()); + GPT2 *Model = gpt2_create(CheckPointPathStr.data()); + *ModelId = Env.addModel(Model); + return ErrNo::Success; +} + +Expect DataLoaderCreate::bodyImpl( + const Runtime::CallingFrame &Frame, uint32_t DataPath, uint32_t DataPathLen, + uint32_t B, uint32_t T, uint32_t ProcessRank, uint32_t NumProcesses, + int32_t ShouldShuffle, uint32_t DataLoaderIdPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-LLMC] Memory instance not found."sv); + return ErrNo::MissingMemory; + } + auto DataPathSpan = MemInst->getSpan(DataPath, DataPathLen); + if (unlikely(DataPathSpan.size() != DataPathLen)) { + spdlog::error( + "[WasmEdge-LLMC] Failed when accessing the input dataloader path memory."sv); + return ErrNo::MissingMemory; + } + + auto *DataLoaderId = MemInst->getPointer(DataLoaderIdPtr); + if (unlikely(DataLoaderId == nullptr)) { + spdlog::error( + "[WasmEdge-LLMC] Failed when accessing the return dataloader memory."sv); + return ErrNo::InvalidArgument; + } + + std::string DataPathStr = std::string( + DataPathSpan.begin(), DataPathSpan.begin() + DataPathSpan.size()); + DataLoader *D = dataloader_create(DataPathStr.data(), B, T, ProcessRank, + NumProcesses, ShouldShuffle); + *DataLoaderId = Env.addDataLoader(D); + return ErrNo::Success; +} + +Expect TokenizerCreate::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t FilePath, uint32_t FilePathLen, + uint32_t TokenizerIdPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-LLMC] Memory instance not found."sv); + return ErrNo::MissingMemory; + } + auto FilePathSpan = MemInst->getSpan(FilePath, FilePathLen); + if (unlikely(FilePathSpan.size() != FilePathLen)) { + spdlog::error( + "[WasmEdge-LLMC] Failed when accessing the input tokenizer path memory."sv); + return ErrNo::MissingMemory; + } + + auto *TokenizerId = MemInst->getPointer(TokenizerIdPtr); + if (unlikely(TokenizerId == nullptr)) { + spdlog::error( + "[WasmEdge-LLMC] Failed when accessing the return tokenizer memory."sv); + return ErrNo::InvalidArgument; + } + std::string FilePathStr = std::string( + FilePathSpan.begin(), FilePathSpan.begin() + FilePathSpan.size()); + Tokenizer *T = tokenizer_create(FilePathStr.data()); + *TokenizerId = Env.addTokenizer(T); + return ErrNo::Success; +} + +Expect ModelTrain::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ModelId, uint32_t TrainDataLoaderId, + uint32_t ValDataLoaderId, + uint32_t TokenizerId, uint32_t B, uint32_t T, + float Lr, uint32_t Epoch) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-LLMC] Memory instance not found."sv); + return ErrNo::MissingMemory; + } + GPT2 *Model = Env.getModel(ModelId); + DataLoader *TrainDataLoader = Env.getDataLoader(TrainDataLoaderId); + DataLoader *ValDataLoader = Env.getDataLoader(ValDataLoaderId); + Tokenizer *Tokenizer = Env.getTokenizer(TokenizerId); + gpt2_train(Model, TrainDataLoader, ValDataLoader, Tokenizer, B, T, Lr, Epoch); + return ErrNo::Success; +} + +} // namespace WasmEdgeLLMC +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_llmc/llmc_func.h b/plugins/wasmedge_llmc/llmc_func.h new file mode 100644 index 00000000..85c786f5 --- /dev/null +++ b/plugins/wasmedge_llmc/llmc_func.h @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "llmc_base.h" +#include "llmc_env.h" + +#include "runtime/callingframe.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeLLMC { + +class ModelCreate : public HostFunction { +public: + explicit ModelCreate(LLMCEnv &HostEnv) : HostFunction(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, + uint32_t CheckPointPath, uint32_t CheckPointPathLen, + uint32_t ModelIdPtr) { + return bodyImpl(Frame, CheckPointPath, CheckPointPathLen, ModelIdPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t CheckPointPath, uint32_t CheckPointPathLen, + uint32_t ModelIdPtr); +}; + +class DataLoaderCreate : public HostFunction { +public: + explicit DataLoaderCreate(LLMCEnv &HostEnv) : HostFunction(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t DataPath, + uint32_t DataPathLen, uint32_t B, uint32_t T, + uint32_t ProcessRank, uint32_t NumProcesses, + int32_t ShouldShuffle, uint32_t DataLoaderIdPtr) { + return bodyImpl(Frame, DataPath, DataPathLen, B, T, ProcessRank, + NumProcesses, ShouldShuffle, DataLoaderIdPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, uint32_t DataPath, + uint32_t DataPathLen, uint32_t B, uint32_t T, + uint32_t ProcessRank, uint32_t NumProcesses, + int32_t ShouldShuffle, uint32_t DataLoaderIdPtr); +}; + +class TokenizerCreate : public HostFunction { +public: + explicit TokenizerCreate(LLMCEnv &HostEnv) : HostFunction(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilePath, + uint32_t FilePathLen, uint32_t TokenizerIdPtr) { + return bodyImpl(Frame, FilePath, FilePathLen, TokenizerIdPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, uint32_t FilePath, + uint32_t FilePathLen, uint32_t TokenizerIdPtr); +}; + +class ModelTrain : public HostFunction { +public: + explicit ModelTrain(LLMCEnv &HostEnv) : HostFunction(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelId, + uint32_t TrainDataLoaderId, uint32_t ValDataLoaderId, + uint32_t TokenizerId, uint32_t B, uint32_t T, float Lr, + uint32_t Epoch) { + return bodyImpl(Frame, ModelId, TrainDataLoaderId, ValDataLoaderId, + TokenizerId, B, T, Lr, Epoch) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ModelId, + uint32_t TrainDataLoaderId, uint32_t ValDataLoaderId, + uint32_t TokenizerId, uint32_t B, uint32_t T, float Lr, + uint32_t Epoch); +}; + +} // namespace WasmEdgeLLMC +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_llmc/llmc_fwd.h b/plugins/wasmedge_llmc/llmc_fwd.h new file mode 100644 index 00000000..afa39f10 --- /dev/null +++ b/plugins/wasmedge_llmc/llmc_fwd.h @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "llmc_env.h" + +extern "C" { + +struct GPT2; +struct Tokenizer; +struct DataLoader; + +GPT2 *gpt2_create(const char *checkpoint_path); + +void gpt2_destroy(GPT2 *model); + +DataLoader *dataloader_create(const char *filename_pattern, size_t B, size_t T, + int process_rank, int num_processes, + int should_shuffle); +void dataloader_destroy(DataLoader *loader); + +Tokenizer *tokenizer_create(const char *filename); + +void tokenizer_destroy(Tokenizer *tokenizer); + +void gpt2_train(GPT2 *model, DataLoader *train_loader, DataLoader *val_loader, + Tokenizer *tokenizer, int B, int T, float lr, int epoch); +} diff --git a/plugins/wasmedge_llmc/llmc_module.cpp b/plugins/wasmedge_llmc/llmc_module.cpp new file mode 100644 index 00000000..8914eb03 --- /dev/null +++ b/plugins/wasmedge_llmc/llmc_module.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "llmc_module.h" +#include "llmc_func.h" + +namespace WasmEdge { +namespace Host { + +WasmEdgeLLMCModule::WasmEdgeLLMCModule() : ModuleInstance("wasmedge_llmc") { + addHostFunc("model_create", std::make_unique(Env)); + addHostFunc("dataloader_create", + std::make_unique(Env)); + addHostFunc("tokenizer_create", + std::make_unique(Env)); + addHostFunc("model_train", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_llmc/llmc_module.h b/plugins/wasmedge_llmc/llmc_module.h new file mode 100644 index 00000000..86a923c3 --- /dev/null +++ b/plugins/wasmedge_llmc/llmc_module.h @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "llmc_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeLLMCModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeLLMCModule(); + +private: + WasmEdgeLLMC::LLMCEnv Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ocr/CMakeLists.txt b/plugins/wasmedge_ocr/CMakeLists.txt new file mode 100644 index 00000000..4079fce0 --- /dev/null +++ b/plugins/wasmedge_ocr/CMakeLists.txt @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_library(wasmedgePluginWasmEdgeOCR + SHARED + ocr_env.cpp + ocr_func.cpp + ocr_module.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeOCR + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeOCR + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeOCR + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeOCR + PRIVATE + wasmedge_shared + ) +endif() + +install( + TARGETS wasmedgePluginWasmEdgeOCR + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) + +message(STATUS "WASI-OCR: Build Tesseract backend for WASI-OCR") +find_package(PkgConfig REQUIRED) +pkg_search_module(TESSERACT REQUIRED tesseract) +pkg_search_module(LEPTONICA REQUIRED lept) + +target_include_directories(wasmedgePluginWasmEdgeOCR + PUBLIC + ${TESSERACT_INCLUDE_DIRS} + ${LEPTONICA_INCLUDE_DIRS} +) + +target_link_libraries(wasmedgePluginWasmEdgeOCR + PUBLIC + ${TESSERACT_LIBRARIES} + ${LEPTONICA_LIBRARIES} +) diff --git a/plugins/wasmedge_ocr/ocr_base.h b/plugins/wasmedge_ocr/ocr_base.h new file mode 100644 index 00000000..dc525dd2 --- /dev/null +++ b/plugins/wasmedge_ocr/ocr_base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#pragma once + +#include "ocr_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeOCR { + +template class HostFunction : public Runtime::HostFunction { +public: + HostFunction(OCREnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + OCREnv &Env; +}; + +} // namespace WasmEdgeOCR +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ocr/ocr_env.cpp b/plugins/wasmedge_ocr/ocr_env.cpp new file mode 100644 index 00000000..111b5136 --- /dev/null +++ b/plugins/wasmedge_ocr/ocr_env.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#include "ocr_env.h" +#include "ocr_module.h" + +namespace WasmEdge { +namespace Host { +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeOCRModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_ocr", + .Description = "A WasmEdge Plugin for Optical Character Recognition (OCR) " + "powered by the Tesseract API.", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 1, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_ocr", + .Description = + "A WasmEdge Plugin for Optical Character Recognition (OCR) " + "powered by the Tesseract API.", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ocr/ocr_env.h b/plugins/wasmedge_ocr/ocr_env.h new file mode 100644 index 00000000..231672a2 --- /dev/null +++ b/plugins/wasmedge_ocr/ocr_env.h @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#pragma once + +#include "common/spdlog.h" +#include "plugin/plugin.h" + +#include +#include + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeOCR { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + MissingMemory = 2, // Caller module is missing a memory export. + Busy = 3 // Device or resource busy. +}; + +class OCREnv { +public: + OCREnv() noexcept { + // Check the Tesseract API by initializing tesseract-ocr with English + // without specifying the tessdata path. + if (TesseractApi->Init(NULL, "eng")) { + spdlog::error( + "[WasmEdge-OCR] Error occurred when initializing tesseract."); + } + } + ~OCREnv() noexcept { + if (TesseractApi) { + TesseractApi->End(); + ; + } + } + tesseract::TessBaseAPI *TesseractApi = new tesseract::TessBaseAPI(); + + static Plugin::PluginRegister Register; +}; + +} // namespace WasmEdgeOCR +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ocr/ocr_func.cpp b/plugins/wasmedge_ocr/ocr_func.cpp new file mode 100644 index 00000000..e75ef40f --- /dev/null +++ b/plugins/wasmedge_ocr/ocr_func.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#include "ocr_func.h" + +#include "common/spdlog.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeOCR { + +Expect NumOfExtractions::body(const Runtime::CallingFrame &Frame, + uint32_t ImagePathPtr, + uint32_t ImagePathLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto ImagePtr = MemInst->getSpan(ImagePathPtr, ImagePathLen); + if (unlikely(ImagePtr.size() != ImagePathLen)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + Pix *image = pixRead(ImagePtr.data()); + + Env.TesseractApi->SetImage(image); + Env.TesseractApi->Recognize(0); + + tesseract::PageIteratorLevel level = tesseract::RIL_WORD; + const char *outText = Env.TesseractApi->GetTSVText(level); + + uint32_t length = strlen(outText); + pixDestroy(&image); + return static_cast(length); +} + +Expect GetOutput::body(const Runtime::CallingFrame &Frame, + uint32_t OutBufferPtr [[maybe_unused]], + uint32_t OutBufferMaxSize [[maybe_unused]]) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + // Check the return value: OutBufferPtr should be valid. + auto Buf = MemInst->getSpan(OutBufferPtr, OutBufferMaxSize); + if (unlikely(Buf.empty())) { + spdlog::error( + "[WasmEdge-OCR] Failed when accessing the return OutBufferPtr memory."); + return static_cast(ErrNo::InvalidArgument); + } + + tesseract::PageIteratorLevel level = tesseract::RIL_WORD; + std::unique_ptr outText = Env.TesseractApi->GetTSVText(level); + std::copy_n(outText, std::min(std::strlen(outText.get()), Buf.size()), + Buf.begin()); + + // remaining free and deltee memory stuff + Env.TesseractApi->End(); + + return static_cast(ErrNo::Success); + // return outText; +} + +} // namespace WasmEdgeOCR +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ocr/ocr_func.h b/plugins/wasmedge_ocr/ocr_func.h new file mode 100644 index 00000000..b00f191d --- /dev/null +++ b/plugins/wasmedge_ocr/ocr_func.h @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#pragma once + +#include "ocr_base.h" + +#include "runtime/callingframe.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeOCR { + +class NumOfExtractions : public HostFunction { +public: + NumOfExtractions(OCREnv &HostEnv) : HostFunction(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t ImagePathPtr, + uint32_t ImagePathLen); +}; + +class GetOutput : public HostFunction { +public: + GetOutput(OCREnv &HostEnv) : HostFunction(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize); +}; + +} // namespace WasmEdgeOCR +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ocr/ocr_module.cpp b/plugins/wasmedge_ocr/ocr_module.cpp new file mode 100644 index 00000000..667c69c1 --- /dev/null +++ b/plugins/wasmedge_ocr/ocr_module.cpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#include "ocr_module.h" +#include "ocr_func.h" + +namespace WasmEdge { +namespace Host { + +WasmEdgeOCRModule::WasmEdgeOCRModule() : ModuleInstance("wasmedge_ocr") { + addHostFunc("num_of_extractions", + std::make_unique(Env)); + addHostFunc("get_output", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ocr/ocr_module.h b/plugins/wasmedge_ocr/ocr_module.h new file mode 100644 index 00000000..4f2b64c1 --- /dev/null +++ b/plugins/wasmedge_ocr/ocr_module.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#pragma once + +#include "ocr_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeOCRModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeOCRModule(); + + WasmEdgeOCR::OCREnv &getEnv() { return Env; } + +private: + WasmEdgeOCR::OCREnv Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/CMakeLists.txt b/plugins/wasmedge_opencvmini/CMakeLists.txt new file mode 100644 index 00000000..ed10e816 --- /dev/null +++ b/plugins/wasmedge_opencvmini/CMakeLists.txt @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +find_package(OpenCV 4 REQUIRED) + +wasmedge_add_library(wasmedgePluginWasmEdgeOpenCVMini + SHARED + opencvmini_env.cpp + opencvmini_func.cpp + opencvmini_module.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeOpenCVMini + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeOpenCVMini + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} + ${OpenCV_INClUDE_DIRS} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeOpenCVMini + PRIVATE + wasmedgeCAPI + ${OpenCV_LIBS} + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeOpenCVMini + PRIVATE + wasmedge_shared + ${OpenCV_LIBS} + ) +endif() + +install( + TARGETS wasmedgePluginWasmEdgeOpenCVMini + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_opencvmini/opencvmini_base.h b/plugins/wasmedge_opencvmini/opencvmini_base.h new file mode 100644 index 00000000..c9bc9d64 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "opencvmini_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template +class WasmEdgeOpenCVMini : public Runtime::HostFunction { +public: + WasmEdgeOpenCVMini(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasmEdgeOpenCVMiniEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_env.cpp b/plugins/wasmedge_opencvmini/opencvmini_env.cpp new file mode 100644 index 00000000..499c69e0 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_env.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "opencvmini_env.h" +#include "opencvmini_module.h" + +namespace WasmEdge { +namespace Host { + +WasmEdgeOpenCVMiniEnvironment::WasmEdgeOpenCVMiniEnvironment() noexcept {} + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeOpenCVMiniModule(); +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_opencvmini", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 1, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_opencvmini", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_env.h b/plugins/wasmedge_opencvmini/opencvmini_env.h new file mode 100644 index 00000000..98c8b785 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_env.h @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { + +class WasmEdgeOpenCVMiniEnvironment { +public: + WasmEdgeOpenCVMiniEnvironment() noexcept; + + std::map MatPool; + + Expect getMat(uint32_t MatKey) { + if (auto V = this->MatPool.find(MatKey); V != this->MatPool.end()) { + return V->second; + } else { + return Unexpect(ErrCode::Value::HostFuncError); + } + } + + Expect insertMat(const cv::Mat &Img) { + // cv::Mat::flags contains magic signature & I believe it's a good enough + // key for this purpose. + this->MatPool[static_cast(Img.flags)] = Img; + return static_cast(Img.flags); + } +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.cpp b/plugins/wasmedge_opencvmini/opencvmini_func.cpp new file mode 100644 index 00000000..9f31a61f --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_func.cpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "opencvmini_func.h" +#include "common/defines.h" +#include "common/errcode.h" + +#include +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { + +Expect +WasmEdgeOpenCVMiniImdecode::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr, uint32_t BufLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto Buf = MemInst->getSpan(BufPtr, BufLen); + if (unlikely(Buf.size() != BufLen)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + std::vector Content(Buf.begin(), Buf.end()); + cv::Mat Img = cv::imdecode(cv::InputArray(Content), cv::IMREAD_COLOR); + + return Env.insertMat(Img); +} + +Expect WasmEdgeOpenCVMiniImshow::body(const Runtime::CallingFrame &Frame, + uint32_t WindowNamePtr, + uint32_t WindowNameLen, + uint32_t MatKey) { + std::string WindowName; + + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto Buf = MemInst->getSpan(WindowNamePtr, WindowNameLen); + if (unlikely(Buf.size() != WindowNameLen)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::copy_n(Buf.data(), WindowNameLen, std::back_inserter(WindowName)); + + if (auto Img = Env.getMat(MatKey); Img) { + cv::imshow(WindowName.c_str(), *Img); + } + + return {}; +} + +Expect WasmEdgeOpenCVMiniWaitKey::body(const Runtime::CallingFrame &, + uint32_t Delay) { + cv::waitKey(static_cast(Delay)); + return {}; +} + +Expect WasmEdgeOpenCVMiniBlur::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + uint32_t KernelWidth, + uint32_t KernelHeight) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::blur(*Src, Dst, cv::Size(KernelWidth, KernelHeight)); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniBilateralFilter::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t D, + double SigmaColor, double SigmaSpace) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::bilateralFilter(*Src, Dst, D, SigmaColor, SigmaSpace); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniBoxFilter::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t Ddepth, + uint32_t KernelWidth, uint32_t KernelHeight) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::boxFilter(*Src, Dst, Ddepth, cv::Size(KernelWidth, KernelHeight)); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniEmptyMat::body(const Runtime::CallingFrame &) { + cv::Mat Kernel; + return Env.insertMat(Kernel); +} + +Expect WasmEdgeOpenCVMiniDilate::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + uint32_t KernelMatKey) { + cv::Mat Dst; + auto Kernel = Env.getMat(KernelMatKey); + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::dilate(*Src, Dst, *Kernel); + } + return Env.insertMat(Dst); +} + +Expect WasmEdgeOpenCVMiniErode::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + uint32_t KernelMatKey) { + cv::Mat Dst; + auto Kernel = Env.getMat(KernelMatKey); + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::erode(*Src, Dst, *Kernel); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniGaussianBlur::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t KernelWidth, + uint32_t KernelHeight, double SigmaX) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::GaussianBlur(*Src, Dst, cv::Size(KernelWidth, KernelHeight), SigmaX); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniLaplacian::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t Ddepth) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::Laplacian(*Src, Dst, Ddepth); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniMedianBlur::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t Ksize) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::medianBlur(*Src, Dst, Ksize); + } + return Env.insertMat(Dst); +} + +Expect WasmEdgeOpenCVMiniPyrDown::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + uint32_t KernelWidth, + uint32_t KernelHeight) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::pyrDown(*Src, Dst, cv::Size(KernelWidth, KernelHeight)); + } + return Env.insertMat(Dst); +} + +Expect WasmEdgeOpenCVMiniPyrUp::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + uint32_t KernelWidth, + uint32_t KernelHeight) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::pyrUp(*Src, Dst, cv::Size(KernelWidth, KernelHeight)); + } + return Env.insertMat(Dst); +} + +Expect WasmEdgeOpenCVMiniImwrite::body(const Runtime::CallingFrame &Frame, + uint32_t TargetFileNamePtr, + uint32_t TargetFileNameLen, + uint32_t MatKey) { + std::string TargetFileName; + + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto Buf = MemInst->getSpan(TargetFileNamePtr, TargetFileNameLen); + if (unlikely(Buf.size() != TargetFileNameLen)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::copy_n(Buf.data(), TargetFileNameLen, + std::back_inserter(TargetFileName)); + + if (auto Img = Env.getMat(MatKey); Img) { + cv::imwrite(TargetFileName.c_str(), *Img); + } + + return {}; +} + +Expect WasmEdgeOpenCVMiniImencode::body( + const Runtime::CallingFrame &Frame, uint32_t ExtPtr, uint32_t ExtLen, + uint32_t MatKey, uint32_t BufPtr, uint32_t BufLen) { + std::string Ext; + + auto *MemInst = Frame.getMemoryByIndex(0); + + auto Buf = MemInst->getSpan(ExtPtr, ExtLen); + if (unlikely(Buf.size() != ExtLen)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::copy_n(Buf.data(), ExtLen, std::back_inserter(Ext)); + + auto Img = Env.getMat(MatKey); + if (!Img) { + spdlog::error("[WasmEdge-OpenCVMini] "sv + "Failed to get matrix by key."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); + if (unlikely(OutSpan.size() != BufLen)) { + spdlog::error("[WasmEdge-OpenCVMini] "sv + "Failed when accessing the image target buffer memory."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + std::vector WriteTo; + cv::imencode(Ext, *Img, WriteTo); + + std::copy_n(WriteTo.begin(), WriteTo.size(), OutSpan.begin()); + + return {}; +} + +Expect +WasmEdgeOpenCVMiniNormalize::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey) { + auto Src = Env.getMat(SrcMatKey); + if (!Src) { + return Unexpect(ErrCode::Value::HostFuncError); + } + cv::Mat Dst; + // convert each elements `v` of `Src` to `(1/255) * v + 0` + Src->convertTo(Dst, CV_32F, 1. / 255., 0.); + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniBilinearSampling::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t OutImgW, + uint32_t OutImgH) { + auto Src = Env.getMat(SrcMatKey); + if (!Src) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + cv::Mat Dst; + cv::resize(*Src, Dst, cv::Size(OutImgW, OutImgH), 0, 0, cv::INTER_LINEAR); + return Env.insertMat(Dst); +} + +Expect WasmEdgeOpenCVMiniRectangle::body( + const Runtime::CallingFrame &, uint32_t SrcMatKey, uint32_t Top, + uint32_t Left, uint32_t Bot, uint32_t Right, double R, double G, double B, + int32_t Thickness, int32_t LineType, int32_t Shift) { + auto Src = Env.getMat(SrcMatKey); + if (!Src) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + cv::Point TopLeft(Top, Left); + cv::Point BottomRight(Bot, Right); + + cv::rectangle(*Src, TopLeft, BottomRight, cv::Scalar(B, G, R), Thickness, + LineType, Shift); + return {}; +} + +Expect WasmEdgeOpenCVMiniCvtColor::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + int32_t Code, + int32_t DestChannelN) { + auto Src = Env.getMat(SrcMatKey); + if (!Src) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto Img = *Src; + + cv::Mat Dst; + cvtColor(Img, Dst, Code, DestChannelN); + return Env.insertMat(Dst); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.h b/plugins/wasmedge_opencvmini/opencvmini_func.h new file mode 100644 index 00000000..587cc6a7 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_func.h @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "opencvmini_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { + +/// Read image from buffer +class WasmEdgeOpenCVMiniImdecode + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniImdecode(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr, + uint32_t BufLen); +}; + +/// Write image into buffer +class WasmEdgeOpenCVMiniImencode + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniImencode(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t ExtPtr, + uint32_t ExtLen, uint32_t MatKey, uint32_t BufPtr, + uint32_t BufLen); +}; + +class WasmEdgeOpenCVMiniImshow + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniImshow(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t WindowNamePtr, + uint32_t WindowNameLen, uint32_t MatKey); +}; + +class WasmEdgeOpenCVMiniWaitKey + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniWaitKey(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t Delay); +}; + +class WasmEdgeOpenCVMiniBlur + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniBlur(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t SrcMatKey, + uint32_t KernelWidth, uint32_t KernelHeight); +}; + +class WasmEdgeOpenCVMiniBilateralFilter + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniBilateralFilter(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t SrcMatKey, + uint32_t D, double SigmaColor, double SigmaSpace); +}; + +class WasmEdgeOpenCVMiniImwrite + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniImwrite(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, + uint32_t TargetFileNamePtr, uint32_t TargetFileNameLen, + uint32_t SrcMatKey); +}; + +/// This is not `cv::normalize`; refer to: +/// https://github.com/WasmEdge/WasmEdge/commit/77051da4995d7318d91a82102a72ce2557151764#diff-3333d926ca87cf4285bfcd6deae45ee310307be66fca8a4ca6f0f8a946743fccR50-R54 +class WasmEdgeOpenCVMiniNormalize + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniNormalize(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t SrcMatKey); +}; + +class WasmEdgeOpenCVMiniBilinearSampling + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniBilinearSampling(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t SrcMatKey, + uint32_t OutImgW, uint32_t OutImgH); +}; + +class WasmEdgeOpenCVMiniBoxFilter + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniBoxFilter(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t Ddepth, uint32_t KernelWidth, + uint32_t KernelHeight); +}; + +class WasmEdgeOpenCVMiniEmptyMat + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniEmptyMat(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &); +}; + +class WasmEdgeOpenCVMiniDilate + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniDilate(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t KernelMatKey); +}; + +class WasmEdgeOpenCVMiniErode + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniErode(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t Kernel); +}; + +class WasmEdgeOpenCVMiniGaussianBlur + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniGaussianBlur(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t KernelWidth, uint32_t KernelHeight, + double SigmaX); +}; + +class WasmEdgeOpenCVMiniLaplacian + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniLaplacian(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t Ddepth); +}; + +class WasmEdgeOpenCVMiniMedianBlur + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniMedianBlur(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t Ksize); +}; + +class WasmEdgeOpenCVMiniPyrDown + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniPyrDown(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t KernelWidth, uint32_t KernelHeight); +}; + +class WasmEdgeOpenCVMiniPyrUp + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniPyrUp(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t KernelWidth, uint32_t KernelHeight); +}; + +class WasmEdgeOpenCVMiniRectangle + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniRectangle(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t Top, uint32_t Left, uint32_t Bot, uint32_t Right, + double R, double G, double B, int32_t Thickness, + int32_t LineType, int32_t Shift); +}; + +class WasmEdgeOpenCVMiniCvtColor + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniCvtColor(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + int32_t Code, int32_t DestChannelN); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_module.cpp b/plugins/wasmedge_opencvmini/opencvmini_module.cpp new file mode 100644 index 00000000..b4c60ed5 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_module.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "opencvmini_module.h" +#include "opencvmini_func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasmEdgeOpenCVMiniModule::WasmEdgeOpenCVMiniModule() + : ModuleInstance("wasmedge_opencvmini") { + addHostFunc("wasmedge_opencvmini_imdecode", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_imencode", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_imwrite", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_blur", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_bilateral_filter", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_box_filter", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmin_dilate", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_erode", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_gaussian_blur", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_Laplacian", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_median_blur", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_pyrDown", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_pyrUp", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_normalize", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_bilinear_sampling", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_cvt_color", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_rectangle", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_imshow", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_waitkey", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_module.h b/plugins/wasmedge_opencvmini/opencvmini_module.h new file mode 100644 index 00000000..3d9296a2 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_module.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "opencvmini_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeOpenCVMiniModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeOpenCVMiniModule(); + + WasmEdgeOpenCVMiniEnvironment &getEnv() { return Env; } + +private: + WasmEdgeOpenCVMiniEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt new file mode 100644 index 00000000..28a4bcce --- /dev/null +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_library(wasmedgePluginWasmEdgeProcess + SHARED + processenv.cpp + processfunc.cpp + processmodule.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeProcess + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeProcess + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeProcess + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeProcess + PRIVATE + wasmedge_shared + ) +endif() + +install( + TARGETS wasmedgePluginWasmEdgeProcess + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_process/processbase.h b/plugins/wasmedge_process/processbase.h new file mode 100644 index 00000000..f7d9fe6e --- /dev/null +++ b/plugins/wasmedge_process/processbase.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "processenv.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template class WasmEdgeProcess : public Runtime::HostFunction { +public: + WasmEdgeProcess(WasmEdgeProcessEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasmEdgeProcessEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processenv.cpp b/plugins/wasmedge_process/processenv.cpp new file mode 100644 index 00000000..774989d6 --- /dev/null +++ b/plugins/wasmedge_process/processenv.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "processenv.h" +#include "processmodule.h" + +#include "po/helper.h" + +#include + +namespace WasmEdge { +namespace Host { + +using namespace std::literals::string_view_literals; + +PO::List WasmEdgeProcessEnvironment::AllowCmd( + PO::Description( + "Allow commands called from wasmedge_process host functions. Each command can be specified as --allow-command `COMMAND`."sv), + PO::MetaVar("COMMANDS"sv)); + +PO::Option WasmEdgeProcessEnvironment::AllowCmdAll(PO::Description( + "Allow all commands called from wasmedge_process host functions."sv)); + +WasmEdgeProcessEnvironment::WasmEdgeProcessEnvironment() noexcept + : AllowedCmd(AllowCmd.value().begin(), AllowCmd.value().end()), + AllowedAll(AllowCmdAll.value()) {} + +namespace { + +void addOptions(const Plugin::Plugin::PluginDescriptor *, + PO::ArgumentParser &Parser) noexcept { + Parser.add_option("allow-command"sv, WasmEdgeProcessEnvironment::AllowCmd) + .add_option("allow-command-all"sv, + WasmEdgeProcessEnvironment::AllowCmdAll); +} + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeProcessModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_process", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_process", + .Description = "", + .Create = create, + }, + }, + .AddOptions = addOptions, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processenv.h b/plugins/wasmedge_process/processenv.h new file mode 100644 index 00000000..a8a98d92 --- /dev/null +++ b/plugins/wasmedge_process/processenv.h @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "common/hash.h" +#include "plugin/plugin.h" +#include "po/argument_parser.h" +#include "po/list.h" +#include "po/option.h" + +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { + +class WasmEdgeProcessEnvironment { +public: + WasmEdgeProcessEnvironment() noexcept; + + /// Default timeout in milliseconds. + static inline const uint32_t DEFAULT_TIMEOUT = 10000; + /// Default polling time in milliseconds. + static inline const uint32_t DEFAULT_POLLTIME = 1; + + /// Commands + std::string Name; + std::vector Args; + std::unordered_map Envs; + + /// IO + std::vector StdIn; + std::vector StdOut; + std::vector StdErr; + + /// Configurations + /// Timeout in milliseconds. + uint32_t TimeOut = DEFAULT_TIMEOUT; + /// Programs in the allowlist. + std::unordered_set AllowedCmd; + /// Flag to allow all programs. + bool AllowedAll; + + /// Results + uint32_t ExitCode = 0; + + static PO::List AllowCmd; + static PO::Option AllowCmdAll; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processfunc.cpp b/plugins/wasmedge_process/processfunc.cpp new file mode 100644 index 00000000..8cd5d942 --- /dev/null +++ b/plugins/wasmedge_process/processfunc.cpp @@ -0,0 +1,351 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "processfunc.h" + +#include "common/defines.h" + +#if WASMEDGE_OS_LINUX || WASMEDGE_OS_MACOS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#elif WASMEDGE_OS_WINDOWS +#endif + +using namespace std::literals; + +namespace WasmEdge { +namespace Host { + +Expect +WasmEdgeProcessSetProgName::body(const Runtime::CallingFrame &Frame, + uint32_t NamePtr, uint32_t NameLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto Buf = MemInst->getSpan(NamePtr, NameLen); + std::copy(Buf.begin(), Buf.end(), std::back_inserter(Env.Name)); + return {}; +} + +Expect WasmEdgeProcessAddArg::body(const Runtime::CallingFrame &Frame, + uint32_t ArgPtr, uint32_t ArgLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto Buf = MemInst->getSpan(ArgPtr, ArgLen); + std::string NewArg; + std::copy(Buf.begin(), Buf.end(), std::back_inserter(NewArg)); + Env.Args.push_back(std::move(NewArg)); + return {}; +} + +Expect WasmEdgeProcessAddEnv::body(const Runtime::CallingFrame &Frame, + uint32_t EnvNamePtr, + uint32_t EnvNameLen, + uint32_t EnvValPtr, + uint32_t EnvValLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto EnvBuf = MemInst->getSpan(EnvNamePtr, EnvNameLen); + const auto ValBuf = MemInst->getSpan(EnvValPtr, EnvValLen); + std::string NewEnv, NewVal; + std::copy(EnvBuf.begin(), EnvBuf.end(), std::back_inserter(NewEnv)); + std::copy(ValBuf.begin(), ValBuf.end(), std::back_inserter(NewVal)); + Env.Envs.emplace(std::move(NewEnv), std::move(NewVal)); + return {}; +} + +Expect WasmEdgeProcessAddStdIn::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr, uint32_t BufLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto const Buf = MemInst->getSpan(BufPtr, BufLen); + Env.StdIn.reserve(Env.StdIn.size() + BufLen); + std::copy(Buf.begin(), Buf.end(), std::back_inserter(Env.StdIn)); + return {}; +} + +Expect WasmEdgeProcessSetTimeOut::body(const Runtime::CallingFrame &, + uint32_t Time) { + Env.TimeOut = Time; + return {}; +} + +Expect WasmEdgeProcessRun::body(const Runtime::CallingFrame &) { +#if WASMEDGE_OS_LINUX || WASMEDGE_OS_MACOS + // Clear outputs. + Env.StdOut.clear(); + Env.StdErr.clear(); + Env.ExitCode = static_cast(-1); + + // Check the command allowlist. + if (!Env.AllowedAll && + Env.AllowedCmd.find(Env.Name) == Env.AllowedCmd.end()) { + std::string Msg = "Permission denied: Command \""; + Msg.append(Env.Name); + Msg.append("\" is not in the white list. Please use --allow-command="); + Msg.append(Env.Name); + Msg.append(" or --allow-command-all to add \""); + Msg.append(Env.Name); + Msg.append("\" command into the white list.\n"); + Env.Name.clear(); + Env.Args.clear(); + Env.Envs.clear(); + Env.StdIn.clear(); + Env.StdErr.reserve(Msg.length()); + std::copy_n(Msg.c_str(), Msg.length(), std::back_inserter(Env.StdErr)); + Env.ExitCode = static_cast(INT8_C(-1)); + Env.TimeOut = Env.DEFAULT_TIMEOUT; + return Env.ExitCode; + } + + // Create pipes for stdin, stdout, and stderr. + int FDStdIn[2], FDStdOut[2], FDStdErr[2]; + if (pipe(FDStdIn) == -1) { + // Create stdin pipe failed. + return Env.ExitCode; + } + if (pipe(FDStdOut) == -1) { + // Create stdout pipe failed. + close(FDStdIn[0]); + close(FDStdIn[1]); + return Env.ExitCode; + } + if (pipe(FDStdErr) == -1) { + // Create stderr pipe failed. + close(FDStdIn[0]); + close(FDStdIn[1]); + close(FDStdOut[0]); + close(FDStdOut[1]); + return Env.ExitCode; + } + + // Create a child process for executing a command. + pid_t PID = fork(); + if (PID == -1) { + // Create process failed. + close(FDStdIn[0]); + close(FDStdIn[1]); + close(FDStdOut[0]); + close(FDStdOut[1]); + close(FDStdErr[0]); + close(FDStdErr[1]); + return Env.ExitCode; + } else if (PID == 0) { + // Child process. Setup pipes. + dup2(FDStdIn[0], 0); + dup2(FDStdOut[1], 1); + dup2(FDStdErr[1], 2); + close(FDStdIn[0]); + close(FDStdIn[1]); + close(FDStdOut[0]); + close(FDStdOut[1]); + close(FDStdErr[0]); + close(FDStdErr[1]); + + // Prepare arguments and environment variables. + std::vector EnvStr; + for (auto &It : Env.Envs) { + EnvStr.push_back(It.first + "=" + It.second); + } + std::vector Argv, Envp; + Argv.push_back(Env.Name.data()); + std::transform(Env.Args.begin(), Env.Args.end(), std::back_inserter(Argv), + [](std::string &S) { return S.data(); }); + std::transform(EnvStr.begin(), EnvStr.end(), std::back_inserter(Envp), + [](std::string &S) { return S.data(); }); + Argv.push_back(nullptr); + Envp.push_back(nullptr); +#if defined(__GLIBC_PREREQ) +#if __GLIBC_PREREQ(2, 11) + if (execvpe(Env.Name.c_str(), &Argv[0], &Envp[0]) == -1) { +#else + if (execve(Env.Name.c_str(), &Argv[0], &Envp[0]) == -1) { +#endif +#else + if (execve(Env.Name.c_str(), &Argv[0], &Envp[0]) == -1) { +#endif + switch (errno) { + case EACCES: + spdlog::error("Permission denied."sv); + break; + case ENOENT: + spdlog::error("Command not found."sv); + break; + default: + spdlog::error("Unknown error."sv); + break; + } + _exit(-1); + } + } else { + // Parent process. Close unused file descriptors. + close(FDStdIn[0]); + close(FDStdOut[1]); + close(FDStdErr[1]); + + // Send inputs. + uint32_t WBytes = 0; + while (WBytes < Env.StdIn.size()) { + uint32_t WriteNum = + std::min(static_cast(PIPE_BUF), Env.StdIn.size() - WBytes); + if (auto Res = write(FDStdIn[1], &Env.StdIn[WBytes], WriteNum); Res > 0) { + WBytes += Res; + } else { + break; + } + } + close(FDStdIn[1]); + + // Waiting for child process and get outputs. + uint8_t Buf[PIPE_BUF]; + ssize_t RBytes; + int ChildStat; + struct timeval TStart, TCurr; + gettimeofday(&TStart, NULL); + while (true) { + gettimeofday(&TCurr, NULL); + if ((TCurr.tv_sec - TStart.tv_sec) * 1000U + + (TCurr.tv_usec - TStart.tv_usec) / 1000U > + Env.TimeOut) { + // Over timeout. Interrupt child process. + kill(PID, SIGKILL); + Env.ExitCode = static_cast(ETIMEDOUT); + break; + } + + // Wait for child process. + pid_t WPID = waitpid(PID, &ChildStat, WNOHANG); + if (WPID == -1) { + // waitpid failed. + Env.ExitCode = static_cast(EINVAL); + break; + } else if (WPID > 0) { + // Child process returned. + Env.ExitCode = static_cast(WEXITSTATUS(ChildStat)); + break; + } + + // Read stdout and stderr. + fd_set FDSet; + int NFD = std::max(FDStdOut[0], FDStdErr[0]) + 1; + FD_ZERO(&FDSet); + FD_SET(FDStdOut[0], &FDSet); + FD_SET(FDStdErr[0], &FDSet); + struct timeval TSelect = {.tv_sec = 0, .tv_usec = 0}; + if (select(NFD, &FDSet, NULL, NULL, &TSelect) > 0) { + if (FD_ISSET(FDStdOut[0], &FDSet)) { + if (RBytes = read(FDStdOut[0], Buf, sizeof(Buf)); RBytes > 0) { + Env.StdOut.reserve(Env.StdOut.size() + RBytes); + std::copy_n(Buf, RBytes, std::back_inserter(Env.StdOut)); + } + } + if (FD_ISSET(FDStdErr[0], &FDSet)) { + if (RBytes = read(FDStdErr[0], Buf, sizeof(Buf)); RBytes > 0) { + Env.StdErr.reserve(Env.StdErr.size() + RBytes); + std::copy_n(Buf, RBytes, std::back_inserter(Env.StdErr)); + } + } + } + usleep(Env.DEFAULT_POLLTIME * 1000); + } + + // Read remaining stdout and stderr. + do { + RBytes = read(FDStdOut[0], Buf, sizeof(Buf)); + if (RBytes > 0) { + Env.StdOut.reserve(Env.StdOut.size() + RBytes); + std::copy_n(Buf, RBytes, std::back_inserter(Env.StdOut)); + } + } while (RBytes > 0); + do { + RBytes = read(FDStdErr[0], Buf, sizeof(Buf)); + if (RBytes > 0) { + Env.StdErr.reserve(Env.StdErr.size() + RBytes); + std::copy_n(Buf, RBytes, std::back_inserter(Env.StdErr)); + } + } while (RBytes > 0); + close(FDStdOut[0]); + close(FDStdErr[0]); + } + + // Reset inputs. + Env.Name.clear(); + Env.Args.clear(); + Env.Envs.clear(); + Env.StdIn.clear(); + Env.TimeOut = Env.DEFAULT_TIMEOUT; + return Env.ExitCode; +#elif WASMEDGE_OS_WINDOWS + spdlog::error("wasmedge_process doesn't support windows now."sv); + return Unexpect(ErrCode::Value::HostFuncError); +#endif +} + +Expect +WasmEdgeProcessGetExitCode::body(const Runtime::CallingFrame &) { + return Env.ExitCode; +} + +Expect +WasmEdgeProcessGetStdOutLen::body(const Runtime::CallingFrame &) { + return static_cast(Env.StdOut.size()); +} + +Expect WasmEdgeProcessGetStdOut::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto Buf = MemInst->getSpan(BufPtr, Env.StdOut.size()); + std::copy_n(Env.StdOut.begin(), std::min(Env.StdOut.size(), Buf.size()), + Buf.begin()); + return {}; +} + +Expect +WasmEdgeProcessGetStdErrLen::body(const Runtime::CallingFrame &) { + return static_cast(Env.StdErr.size()); +} + +Expect WasmEdgeProcessGetStdErr::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto Buf = MemInst->getSpan(BufPtr, Env.StdErr.size()); + std::copy_n(Env.StdErr.begin(), std::min(Env.StdErr.size(), Buf.size()), + Buf.begin()); + return {}; +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processfunc.h b/plugins/wasmedge_process/processfunc.h new file mode 100644 index 00000000..9746d433 --- /dev/null +++ b/plugins/wasmedge_process/processfunc.h @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "processbase.h" + +#include "runtime/callingframe.h" + +#include + +namespace WasmEdge { +namespace Host { + +class WasmEdgeProcessSetProgName + : public WasmEdgeProcess { +public: + WasmEdgeProcessSetProgName(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t NamePtr, + uint32_t NameLen); +}; + +class WasmEdgeProcessAddArg : public WasmEdgeProcess { +public: + WasmEdgeProcessAddArg(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ArgPtr, + uint32_t ArgLen); +}; + +class WasmEdgeProcessAddEnv : public WasmEdgeProcess { +public: + WasmEdgeProcessAddEnv(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t EnvNamePtr, + uint32_t EnvNameLen, uint32_t EnvValPtr, + uint32_t EnvValLen); +}; + +class WasmEdgeProcessAddStdIn + : public WasmEdgeProcess { +public: + WasmEdgeProcessAddStdIn(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr, + uint32_t BufLen); +}; + +class WasmEdgeProcessSetTimeOut + : public WasmEdgeProcess { +public: + WasmEdgeProcessSetTimeOut(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Time); +}; + +class WasmEdgeProcessRun : public WasmEdgeProcess { +public: + WasmEdgeProcessRun(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class WasmEdgeProcessGetExitCode + : public WasmEdgeProcess { +public: + WasmEdgeProcessGetExitCode(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class WasmEdgeProcessGetStdOutLen + : public WasmEdgeProcess { +public: + WasmEdgeProcessGetStdOutLen(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class WasmEdgeProcessGetStdOut + : public WasmEdgeProcess { +public: + WasmEdgeProcessGetStdOut(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr); +}; + +class WasmEdgeProcessGetStdErrLen + : public WasmEdgeProcess { +public: + WasmEdgeProcessGetStdErrLen(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class WasmEdgeProcessGetStdErr + : public WasmEdgeProcess { +public: + WasmEdgeProcessGetStdErr(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processmodule.cpp b/plugins/wasmedge_process/processmodule.cpp new file mode 100644 index 00000000..613be81d --- /dev/null +++ b/plugins/wasmedge_process/processmodule.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "processmodule.h" +#include "processfunc.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasmEdgeProcessModule::WasmEdgeProcessModule() + : ModuleInstance("wasmedge_process") { + addHostFunc("wasmedge_process_set_prog_name", + std::make_unique(Env)); + addHostFunc("wasmedge_process_add_arg", + std::make_unique(Env)); + addHostFunc("wasmedge_process_add_env", + std::make_unique(Env)); + addHostFunc("wasmedge_process_add_stdin", + std::make_unique(Env)); + addHostFunc("wasmedge_process_set_timeout", + std::make_unique(Env)); + addHostFunc("wasmedge_process_run", + std::make_unique(Env)); + addHostFunc("wasmedge_process_get_exit_code", + std::make_unique(Env)); + addHostFunc("wasmedge_process_get_stdout_len", + std::make_unique(Env)); + addHostFunc("wasmedge_process_get_stdout", + std::make_unique(Env)); + addHostFunc("wasmedge_process_get_stderr_len", + std::make_unique(Env)); + addHostFunc("wasmedge_process_get_stderr", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processmodule.h b/plugins/wasmedge_process/processmodule.h new file mode 100644 index 00000000..6482ee68 --- /dev/null +++ b/plugins/wasmedge_process/processmodule.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "processenv.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeProcessModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeProcessModule(); + + WasmEdgeProcessEnvironment &getEnv() { return Env; } + +private: + WasmEdgeProcessEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt new file mode 100644 index 00000000..1e3038b9 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + + +if(WASMEDGE_PLUGIN_STABLEDIFFUSION_CUDA) + message(STATUS "Stable diffusion plugin: Enable SD_CUDA") + set(SD_CUDA ON CACHE BOOL "Stable diffusion plugin: Enable SD_CUDA") +else() + message(STATUS "Stable diffusion plugin: Disable SD_CUDA") + set(SD_CUDA OFF CACHE BOOL "Stable diffusion plugin: Disable SD_CUDA") +endif() + +if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND WASMEDGE_PLUGIN_STABLEDIFFUSION_METAL) + message(STATUS "Stable diffusion plugin: Enable SD_METAL") + set(SD_METAL ON CACHE BOOL "Stable diffusion plugin: Enable SD_METAL") + set(GGML_METAL_EMBED_LIBRARY ON) +else() + message(STATUS "Stable diffusion plugin: Disable SD_METAL") + set(SD_METAL OFF CACHE BOOL "Stable diffusion plugin: Disable SD_METAL") +endif() + +if(WASMEDGE_PLUGIN_STABLEDIFFUSION_OPENMP) + message(STATUS "Stable diffusion plugin: Enable SD_OPENMP") + set(GGML_OPENMP ON) +else() + message(STATUS "Stable diffusion plugin: Disable SD_OPENMP") + set(GGML_OPENMP OFF) +endif() + +# setup stable diffusion +message(STATUS "Downloading stable diffusion source") +FetchContent_Declare( + stable-diffusion + GIT_REPOSITORY https://github.com/leejet/stable-diffusion.cpp.git + GIT_TAG dcf91f9e0f2cbf9da472ee2a556751ed4bab2d2a + GIT_SHALLOW TRUE + ) +set(SD_BUILD_SHARED_LIBS ON CACHE INTERNAL "Stable diffusion plugin: Build shared libs") +FetchContent_MakeAvailable(stable-diffusion) +set_property(TARGET stable-diffusion PROPERTY POSITION_INDEPENDENT_CODE ON) +if(APPLE AND CMAKE_SYSTEM_VERSION VERSION_LESS 23) + # `cblas_sgemm()` introduced in macOS 13.3. + set(GGML_NO_ACCELERATE ON CACHE INTERNAL "Stable diffusion plugin: Turn off accelerate") +endif() + + +wasmedge_add_library(wasmedgePluginWasmEdgeStableDiffusion + SHARED + sd_env.cpp + sd_func.cpp + sd_module.cpp +) + +target_link_libraries(wasmedgePluginWasmEdgeStableDiffusion + PRIVATE + stable-diffusion + ${CMAKE_THREAD_LIBS_INIT} +) + +target_compile_options(wasmedgePluginWasmEdgeStableDiffusion + PUBLIC + -DWASMEDGE_PLUGIN +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeStableDiffusion + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeStableDiffusion + PRIVATE + wasmedge_shared + ) +endif() + +target_include_directories(wasmedgePluginWasmEdgeStableDiffusion + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_include_directories(wasmedgePluginWasmEdgeStableDiffusion + SYSTEM + PRIVATE + "${stable-diffusion_SOURCE_DIR}/thirdparty" +) + +if (MSVC) + target_compile_options( + stable-diffusion + PRIVATE + /wd4459 + /wd4100 + /wd4127 + /wd4701 + ) +else() + target_compile_options( + stable-diffusion + PRIVATE + -Wno-unused-function + -Wno-unused-variable + -Wno-unused-parameter + -Wno-missing-field-initializers + -Wno-deprecated-declarations + -Wno-braced-scalar-init + -Wno-unused-value + -Wno-uninitialized + -Wno-format + -Wno-enum-compare + ) +endif() + +install( + TARGETS wasmedgePluginWasmEdgeStableDiffusion + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_stablediffusion/sd_base.h b/plugins/wasmedge_stablediffusion/sd_base.h new file mode 100644 index 00000000..5ba7441c --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sd_base.h @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "sd_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace StableDiffusion { + +template class Func : public Runtime::HostFunction { +public: + Func(SDEnviornment &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + static constexpr uint32_t castErrNo(StableDiffusion::ErrNo E) noexcept { + return static_cast(E); + } + SDEnviornment &Env; +}; + +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp new file mode 100644 index 00000000..438857e4 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "sd_env.h" +#include "sd_module.h" + +using namespace std::literals; + +namespace WasmEdge { +namespace Host { +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new SDModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_stablediffusion", + .Description = "Stable Diffusion plug-in for WasmEdge.", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 4, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_stablediffusion", + .Description = + "This module contains Stable Diffusion host functions.", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace + +namespace StableDiffusion { + +uint32_t SDEnviornment::addContext(sd_ctx_t *Ctx, int32_t Nthreads, + uint32_t Wtype) noexcept { + Contexts.push_back({Ctx, Nthreads, Wtype}); + return Contexts.size() - 1; +} + +void SDEnviornment::freeContext(const uint32_t Id) noexcept { + sd_ctx_t *SDCtx = Contexts[Id].Context; + free_sd_ctx(SDCtx); + Contexts.erase(Contexts.begin() + Id - 1); +} + +sd_ctx_t *SDEnviornment::getContext(const uint32_t Id) noexcept { + if (Id >= Contexts.size()) { + return nullptr; + } + return Contexts[Id].Context; +} + +void SBLog(enum sd_log_level_t Level, const char *Log, void *) { + if (!Log) { + return; + } + std::string LevelStr; + switch (Level) { + case SD_LOG_DEBUG: + LevelStr = "DEBUG"; + break; + case SD_LOG_INFO: + LevelStr = "INFO"; + break; + case SD_LOG_WARN: + LevelStr = "WARN"; + break; + case SD_LOG_ERROR: + LevelStr = "ERROR"; + break; + default: + LevelStr = "?????"; + break; + } + + spdlog::info("[WasmEdge-StableDiffusion] SD-log: [{}] {}"sv, LevelStr, Log); +} + +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_env.h b/plugins/wasmedge_stablediffusion/sd_env.h new file mode 100644 index 00000000..b58dd28f --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sd_env.h @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "stable-diffusion.h" + +#include "plugin/plugin.h" +#include + +namespace WasmEdge { +namespace Host { +namespace StableDiffusion { + +void SBLog(enum sd_log_level_t Level, const char *Log, void *); + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. +}; + +struct ContextInfo { + sd_ctx_t *Context; + int32_t NThreads; + uint32_t Wtype; +}; + +class SDEnviornment { +public: + SDEnviornment() noexcept { + if (EnableSDLog) { + sd_set_log_callback(SBLog, nullptr); + } + }; + uint32_t addContext(sd_ctx_t *Ctx, int32_t Nthreads, uint32_t Wtype) noexcept; + void freeContext(const uint32_t Id) noexcept; + sd_ctx_t *getContext(const uint32_t Id) noexcept; + size_t getContextSize() noexcept { return Contexts.size(); } + int32_t getNThreads(const uint32_t Id) noexcept { + return Contexts[Id].NThreads; + } + uint32_t getWtype(const uint32_t Id) noexcept { return Contexts[Id].Wtype; } + +private: + bool EnableSDLog = false; + std::vector Contexts; +}; + +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp new file mode 100644 index 00000000..9ad24fc4 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -0,0 +1,591 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "sd_func.h" +#include "common/spdlog.h" +#include "sd_env.h" +#include "spdlog/spdlog.h" +#include "stable-diffusion.h" + +#define STB_IMAGE_IMPLEMENTATION +#define STB_IMAGE_STATIC +#include "stb_image.h" + +#define STB_IMAGE_WRITE_IMPLEMENTATION +#define STB_IMAGE_WRITE_STATIC +#include "stb_image_write.h" + +#define STB_IMAGE_RESIZE_IMPLEMENTATION +#define STB_IMAGE_RESIZE_STATIC +#include "stb_image_resize.h" + +namespace WasmEdge { +namespace Host { +namespace StableDiffusion { + +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-StableDiffusion] Memory instance not found."sv); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define SESSION_CHECK(Out, SessionID, Message, ErrNo) \ + auto *Out = Env.getContext(SessionID); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-StableDiffusion] "sv Message); \ + return static_cast(ErrNo); \ + } + +#define MEM_SPAN_CHECK(OutSpan, MemInst, Type, BufPtr, BufLen, Message) \ + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); \ + if (unlikely(OutSpan.size() != BufLen)) { \ + spdlog::error("[WasmEdge-StableDiffusion] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_SV_CHECK(OutSV, MemInst, BufPtr, BufLen, Message) \ + auto OutSV = MemInst->getStringView(BufPtr, BufLen); \ + if (unlikely(OutSV.size() != BufLen)) { \ + spdlog::error("[WasmEdge-StableDiffusion] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_PTR_CHECK(OutPtr, MemInst, Type, Offset, Message) \ + Type *OutPtr = MemInst->getPointer(Offset); \ + if (unlikely(OutPtr == nullptr)) { \ + spdlog::error("[WasmEdge-StableDiffusion] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +bool parameterCheck(SDEnviornment &Env, uint32_t Width, uint32_t Height, + uint32_t SessionId) { + if (SessionId >= Env.getContextSize()) { + spdlog::error("[WasmEdge-StableDiffusion] Session ID is invalid."sv); + return false; + } + if (Width % 64 != 0) { + spdlog::error( + "[WasmEdge-StableDiffusion] Width must be a multiple of 64 and greater than 0."sv); + return false; + } + if (Height % 64 != 0) { + spdlog::error( + "[WasmEdge-StableDiffusion] Height must be a multiple of 64 and greater than 0."sv); + return false; + } + return true; +} + +sd_image_t *readControlImage(Span ControlImage, int Width, int Height, + bool CannyPreprocess) { + uint8_t *ControlImageBuffer = nullptr; + sd_image_t *ControlImg = nullptr; + int Channel = 0; + std::string ControlImagePath(ControlImage.begin(), ControlImage.end()); + + if (ControlImagePath.substr(0, 5) == "path:"sv) { + ControlImageBuffer = stbi_load(ControlImagePath.substr(5).data(), &Width, + &Height, &Channel, 3); + } else { + ControlImageBuffer = stbi_load_from_memory( + ControlImage.data(), ControlImage.size(), &Width, &Height, &Channel, 3); + } + + if (ControlImageBuffer == nullptr) { + spdlog::error( + "[WasmEdge-StableDiffusion] Load image from control image failed."sv); + return nullptr; + } + ControlImg = + new sd_image_t{static_cast(Width), + static_cast(Height), 3, ControlImageBuffer}; + if (CannyPreprocess) { // apply preprocessor + ControlImg->data = + preprocess_canny(ControlImg->data, ControlImg->width, + ControlImg->height, 0.08f, 0.08f, 0.8f, 1.0f, false); + } + free(ControlImageBuffer); + return ControlImg; +} + +sd_image_t readMaskImage(Span MaskImage, int Width, int Height) { + uint8_t *MaskImageBuffer = NULL; + std::string MaskImagePath(MaskImage.begin(), MaskImage.end()); + int Channel = 0; + if (MaskImagePath.substr(0, 5) == "path:"sv) { + MaskImageBuffer = + stbi_load(MaskImagePath.substr(5).data(), &Width, &Height, &Channel, 3); + } else if (MaskImage.size() != 0) { + MaskImageBuffer = stbi_load_from_memory(MaskImage.data(), MaskImage.size(), + &Width, &Height, &Channel, 3); + } else { + std::vector Arr(Width * Height, 255); + MaskImageBuffer = Arr.data(); + } + return {static_cast(Width), static_cast(Height), 1, + MaskImageBuffer}; +} + +void upscalerModel(const char *UpscaleModelPath, uint32_t UpscaleRepeats, + int32_t NThreads, uint32_t BatchCount, sd_image_t *Results) { + // unused for RealESRGAN_x4plus_anime_6B.pth + int UpscaleFactor = 4; + upscaler_ctx_t *UpscalerCtx = new_upscaler_ctx(UpscaleModelPath, NThreads); + if (UpscalerCtx == nullptr) { + spdlog::error("[WasmEdge-StableDiffusion] Create upscaler ctx failed."sv); + } else { + for (uint32_t I = 0; I < BatchCount; I++) { + if (Results[I].data == nullptr) { + continue; + } + sd_image_t CurrentImage = Results[I]; + for (uint32_t U = 0; U < UpscaleRepeats; ++U) { + sd_image_t UpscaledImage = + upscale(UpscalerCtx, CurrentImage, UpscaleFactor); + if (UpscaledImage.data == nullptr) { + spdlog::error("[WasmEdge-StableDiffusion] Upscale failed."sv); + break; + } + free(CurrentImage.data); + CurrentImage = UpscaledImage; + } + // Set the final upscaled image as the result. + Results[I] = CurrentImage; + } + } + free(UpscalerCtx); +} + +bool saveResults(sd_image_t *Results, uint32_t BatchCount, + std::string OutputPath, uint32_t OutputPathLen, + uint32_t *BytesWritten, uint32_t OutBufferMaxSize, + uint8_t *OutputBufferSpanPtr) { + int Len; + unsigned char *Png = stbi_write_png_to_mem( + reinterpret_cast(Results), 0, Results->width, + Results->height, Results->channel, &Len, nullptr); + if (OutputPathLen != 0) { + size_t Last = OutputPath.find_last_of("."); + std::string DummyName = OutputPath; + std::string Extension = ".png"; + if (Last != std::string::npos) { + std::string LastStr = OutputPath.substr(Last); + if (LastStr == ".png" || LastStr == ".PNG") { + DummyName = OutputPath.substr(0, Last); + Extension = LastStr; + } + } + for (uint32_t I = 0; I < BatchCount; I++) { + if (Results[I].data != nullptr) { + std::string FinalImagePath; + if (I <= 0) + FinalImagePath = DummyName + Extension; + else + FinalImagePath += "_" + std::to_string(I + 1) + Extension; + stbi_write_png(FinalImagePath.c_str(), Results[I].width, + Results[I].height, Results[I].channel, Results[I].data, + 0, nullptr); + spdlog::info("[WasmEdge-StableDiffusion] Save result image to {}."sv, + FinalImagePath.c_str()); + free(Results[I].data); + Results[I].data = nullptr; + } + } + } + *BytesWritten = Len; + if (OutBufferMaxSize < *BytesWritten) { + spdlog::error("[WasmEdge-StableDiffusion] Output buffer is not enough."sv); + free(Png); + free(Results); + return false; + } + std::copy_n(Png, *BytesWritten, OutputBufferSpanPtr); + free(Png); + free(Results); + return true; +} + +Expect SDConvert::body(const Runtime::CallingFrame &Frame, + uint32_t ModelPathPtr, uint32_t ModelPathLen, + uint32_t VaeModelPathPtr, + uint32_t VaeModelPathLen, + uint32_t OutputPathPtr, uint32_t OutputPathLen, + uint32_t WType) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input parameter value. + MEM_SPAN_CHECK(ModelPathSpan, MemInst, char, ModelPathPtr, ModelPathLen, + "Failed when accessing the input model path memory."sv) + MEM_SPAN_CHECK(VaeModelPathSpan, MemInst, char, VaeModelPathPtr, + VaeModelPathLen, + "Failed when accessing the input vae model path memory."sv) + MEM_SPAN_CHECK(OutputPathSpan, MemInst, char, OutputPathPtr, OutputPathLen, + "Failed when accessing the output path memory."sv) + std::string ModelPath = std::string( + ModelPathSpan.begin(), ModelPathSpan.begin() + ModelPathSpan.size()); + std::string VaeModelPath = + std::string(VaeModelPathSpan.begin(), + VaeModelPathSpan.begin() + VaeModelPathSpan.size()); + std::string OutputPath = std::string( + OutputPathSpan.begin(), OutputPathSpan.begin() + OutputPathSpan.size()); + + spdlog::info("[WasmEdge-StableDiffusion] Convert model: {} to {}."sv, + ModelPath.data(), OutputPath.data()); + std::ifstream Fin(ModelPath.data(), std::ios::in | std::ios::binary); + if (!Fin) { + Fin.close(); + spdlog::error("[WasmEdge-StableDiffusion] Model not found."sv); + return static_cast(ErrNo::InvalidArgument); + } + Fin.close(); + // Convert model. + bool Ret = ::convert(ModelPath.data(), VaeModelPath.data(), OutputPath.data(), + static_cast(WType)); + if (!Ret) { + spdlog::error("[WasmEdge-StableDiffusion] Failed to convert model."sv); + return static_cast(ErrNo::InvalidArgument); + } + + return static_cast(ErrNo::Success); +} + +Expect SDCreateContext::body( + const Runtime::CallingFrame &Frame, uint32_t ModelPathPtr, + uint32_t ModelPathLen, uint32_t clipLPathPtr, uint32_t clipLPathLen, + uint32_t clipGPathPtr, uint32_t clipGPathLen, uint32_t t5xxlPathPtr, + uint32_t t5xxlPathLen, uint32_t diffusionModelPathPtr, + uint32_t diffusionModelPathLen, uint32_t VaePathPtr, uint32_t VaePathLen, + uint32_t TaesdPathPtr, uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, + uint32_t ControlNetPathLen, uint32_t LoraModelDirPtr, + uint32_t LoraModelDirLen, uint32_t EmbedDirPtr, uint32_t EmbedDirLen, + uint32_t IdEmbedDirPtr, uint32_t IdEmbedDirLen, uint32_t VaeDecodeOnly, + uint32_t VaeTiling, int32_t NThreads, uint32_t Wtype, uint32_t RngType, + uint32_t Schedule, uint32_t ClipOnCpu, uint32_t ControlNetCpu, + uint32_t VaeOnCpu, uint32_t DiffusionFlashAttn, uint32_t SessiontIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + // Check the input model buffer. + MEM_SPAN_CHECK(ModelPathSpan, MemInst, char, ModelPathPtr, ModelPathLen, + "Failed when accessing the input model path memory."sv) + MEM_SPAN_CHECK(clipLPathSpan, MemInst, char, clipLPathPtr, clipLPathLen, + "Failed when accessing the input clipL path memory."sv) + MEM_SPAN_CHECK(clipGPathSpan, MemInst, char, clipGPathPtr, clipGPathLen, + "Failed when accessing the input clipG path memory."sv) + MEM_SPAN_CHECK(t5xxlPathSpan, MemInst, char, t5xxlPathPtr, t5xxlPathLen, + "Failed when accessing the input t5xxl path memory."sv) + MEM_SPAN_CHECK( + diffusionModelPathSpan, MemInst, char, diffusionModelPathPtr, + diffusionModelPathLen, + "Failed when accessing the input diffusion model path memory."sv) + MEM_SPAN_CHECK(VaePathSpan, MemInst, char, VaePathPtr, VaePathLen, + "Failed when accessing the input vae path memory."sv) + MEM_SPAN_CHECK(ControlNetPathSpan, MemInst, char, ControlNetPathPtr, + ControlNetPathLen, + "Failed when accessing the input control net path memory."sv) + MEM_SPAN_CHECK(LoraModelDirSpan, MemInst, char, LoraModelDirPtr, + LoraModelDirLen, + "Failed when accessing the input lora model path memory."sv) + MEM_SPAN_CHECK(TaesdPathSpan, MemInst, char, TaesdPathPtr, TaesdPathLen, + "Failed when accessing the input taesd path memory."sv) + MEM_SPAN_CHECK(EmbedDirSpan, MemInst, char, EmbedDirPtr, EmbedDirLen, + "Failed when accessing the input embedded directory memory."sv) + MEM_SPAN_CHECK( + IdEmbedDirSpan, MemInst, char, IdEmbedDirPtr, IdEmbedDirLen, + "Failed when accessing the input id dembed directory memory."sv) + MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessiontIdPtr, + "Failed when accessing the return SessionID memory."sv) + std::string ModelPath = + std::string(ModelPathSpan.begin(), ModelPathSpan.end()); + std::string VaePath = std::string(VaePathSpan.begin(), VaePathSpan.end()); + std::string TaesdPath = + std::string(TaesdPathSpan.begin(), TaesdPathSpan.end()); + std::string ControlNetPath = + std::string(ControlNetPathSpan.begin(), ControlNetPathSpan.end()); + std::string LoraModelDir = + std::string(LoraModelDirSpan.begin(), LoraModelDirSpan.end()); + std::string EmbedDir = std::string(EmbedDirSpan.begin(), EmbedDirSpan.end()); + std::string IdEmbedDir = + std::string(IdEmbedDirSpan.begin(), IdEmbedDirSpan.end()); + std::string clipLPath = + std::string(clipLPathSpan.begin(), clipLPathSpan.end()); + std::string clipGPath = + std::string(clipGPathSpan.begin(), clipGPathSpan.end()); + std::string t5xxlPath = + std::string(t5xxlPathSpan.begin(), t5xxlPathSpan.end()); + std::string diffusionModelPath = + std::string(diffusionModelPathSpan.begin(), diffusionModelPathSpan.end()); + if (NThreads == -1) { + NThreads = get_num_physical_cores(); + } + // Check parameters + if (ModelPathLen == 0 && diffusionModelPathLen == 0) { + spdlog::error( + "[WasmEdge-StableDiffusion] The following arguments are required: ModelPath / DiffusionModelPath"sv); + return static_cast(ErrNo::InvalidArgument); + } + // Create context and import graph. + spdlog::debug("[WasmEdge-StableDiffusion] Create context."sv); + sd_ctx_t *Ctx = new_sd_ctx( + ModelPath.data(), clipLPath.data(), clipGPath.data(), t5xxlPath.data(), + diffusionModelPath.data(), VaePath.data(), TaesdPath.data(), + ControlNetPath.data(), LoraModelDir.data(), EmbedDir.data(), + IdEmbedDir.data(), static_cast(VaeDecodeOnly), + static_cast(VaeTiling), false, NThreads, + static_cast(Wtype), static_cast(RngType), + static_cast(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu, + DiffusionFlashAttn); + if (Ctx == nullptr) { + spdlog::error("[WasmEdge-StableDiffusion] Failed to create context."sv); + return static_cast(ErrNo::InvalidArgument); + } + *SessionId = Env.addContext(Ctx, NThreads, static_cast(Wtype)); + return static_cast(ErrNo::Success); +} + +Expect SDTextToImage::body( + const Runtime::CallingFrame &Frame, uint32_t PromptPtr, uint32_t PromptLen, + uint32_t SessionId, uint32_t ControlImagePtr, uint32_t ControlImageLen, + uint32_t NegativePromptPtr, uint32_t NegativePromptLen, float Guidance, + uint32_t Width, uint32_t Height, int32_t ClipSkip, float CfgScale, + uint32_t SampleMethod, uint32_t SampleSteps, uint32_t Seed, + uint32_t BatchCount, float ControlStrength, float StyleRatio, + uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, + uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, + uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, + uint32_t UpscaleRepeats, uint32_t SkipLayersPtr, uint32_t SkipLayersLen, + float SlgScale, float SkipLayerStart, float SkipLayerEnd, + uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + // Check the input model buffer. + MEM_SPAN_CHECK(PromptSpan, MemInst, char, PromptPtr, PromptLen, + "Failed when accessing the promp memory."sv) + MEM_SPAN_CHECK(NegativePromptSpan, MemInst, char, NegativePromptPtr, + NegativePromptLen, + "Failed when accessing the input negative prompt memory."sv) + MEM_SPAN_CHECK(InputIdImagesDirSpan, MemInst, char, InputIdImagesDirPtr, + InputIdImagesDirLen, + "Failed when accessing the input id images path memory."sv) + MEM_SPAN_CHECK(OutputBufferSpan, MemInst, uint8_t, OutBufferPtr, + OutBufferMaxSize, + "Failed when accessing the Output Buffer memory."sv) + MEM_PTR_CHECK(BytesWritten, MemInst, uint32_t, BytesWrittenPtr, + "Failed when accessing the return bytes written memory."sv) + MEM_SPAN_CHECK(OutputPathSpan, MemInst, char, OutputPathPtr, OutputPathLen, + "Failed when accessing the output path memory."sv) + MEM_SPAN_CHECK(SkipLayersSpan, MemInst, int32_t, SkipLayersPtr, SkipLayersLen, + "Failed when accessing the SkipLayers memory."sv) + std::string Prompt(PromptSpan.begin(), PromptSpan.end()); + std::string NegativePrompt(NegativePromptSpan.begin(), + NegativePromptSpan.end()); + std::string InputIdImagesDir(InputIdImagesDirSpan.begin(), + InputIdImagesDirSpan.end()); + std::string OutputPath(OutputPathSpan.begin(), OutputPathSpan.end()); + if (!parameterCheck(Env, Width, Height, SessionId)) { + return static_cast(ErrNo::InvalidArgument); + } + SESSION_CHECK(SDCtx, SessionId, "Session ID is invalid."sv, + ErrNo::InvalidArgument) + sd_image_t *Results = nullptr; + sd_image_t *ControlImage = nullptr; + // Read control image + if (ControlImageLen != 0) { + MEM_SPAN_CHECK(ControlImageSpan, MemInst, uint8_t, ControlImagePtr, + ControlImageLen, + "Failed when accessing the control image memory."sv) + ControlImage = + readControlImage(ControlImageSpan, Width, Height, CannyPreprocess); + } + // Generate images + spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); + Results = txt2img( + SDCtx, Prompt.data(), NegativePrompt.data(), ClipSkip, CfgScale, Guidance, + Width, Height, sample_method_t(SampleMethod), SampleSteps, Seed, + BatchCount, ControlImage, ControlStrength, StyleRatio, NormalizeInput, + InputIdImagesDir.data(), SkipLayersSpan.data(), SkipLayersSpan.size(), + SlgScale, SkipLayerStart, SkipLayerEnd); + free(ControlImage); + if (Results == nullptr) { + spdlog::error("[WasmEdge-StableDiffusion] Generate failed."sv); + Env.freeContext(SessionId); + return static_cast(ErrNo::RuntimeError); + } + // Upscale image + if (UpscaleModelPathLen != 0) { + MEM_SPAN_CHECK(UpscaleModelSpan, MemInst, char, UpscaleModelPathPtr, + UpscaleModelPathLen, + "Failed when accessing the Upscaler Image memory."sv) + std::string UpscaleModelPath(UpscaleModelSpan.begin(), + UpscaleModelSpan.end()); + upscalerModel(UpscaleModelPath.data(), UpscaleRepeats, + Env.getNThreads(SessionId), BatchCount, Results); + } + // Save results + if (!saveResults(Results, BatchCount, OutputPath, OutputPathLen, BytesWritten, + OutBufferMaxSize, OutputBufferSpan.data())) { + return static_cast(ErrNo::RuntimeError); + } + return static_cast(ErrNo::Success); +} + +Expect SDImageToImage::body( + const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, + uint32_t MaskImagePtr, uint32_t MaskImageLen, uint32_t SessionId, + float Guidance, uint32_t Width, uint32_t Height, uint32_t ControlImagePtr, + uint32_t ControlImageLen, uint32_t PromptPtr, uint32_t PromptLen, + uint32_t NegativePromptPtr, uint32_t NegativePromptLen, int32_t ClipSkip, + float CfgScale, uint32_t SampleMethod, uint32_t SampleSteps, float Strength, + uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, + uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, + uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, + uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, + uint32_t UpscaleRepeats, uint32_t SkipLayersPtr, uint32_t SkipLayersLen, + float SlgScale, float SkipLayerStart, float SkipLayerEnd, + uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + // Check the input parameter valid. + MEM_SPAN_CHECK(ImageSpan, MemInst, uint8_t, ImagePtr, ImageLen, + "Failed when accessing the input image memory."sv) + MEM_SPAN_CHECK(MaskImageSpan, MemInst, uint8_t, MaskImagePtr, MaskImageLen, + "Failed when accessing the input mask image memory."sv) + MEM_SPAN_CHECK(PromptSpan, MemInst, char, PromptPtr, PromptLen, + "Failed when accessing the promp memory."sv) + MEM_SPAN_CHECK(NegativePromptSpan, MemInst, char, NegativePromptPtr, + NegativePromptLen, + "Failed when accessing the input negative prompt memory."sv) + MEM_SPAN_CHECK(InputIdImagesDirSpan, MemInst, char, InputIdImagesDirPtr, + InputIdImagesDirLen, + "Failed when accessing the input id images path memory."sv) + MEM_SPAN_CHECK(OutputBufferSpan, MemInst, uint8_t, OutBufferPtr, + OutBufferMaxSize, + "Failed when accessing the Output Buffer memory."sv) + MEM_PTR_CHECK(BytesWritten, MemInst, uint32_t, BytesWrittenPtr, + "Failed when accessing the return bytes written memory."sv) + MEM_SPAN_CHECK(OutputPathSpan, MemInst, char, OutputPathPtr, OutputPathLen, + "Failed when accessing the output path memory."sv) + MEM_SPAN_CHECK(SkipLayersSpan, MemInst, int32_t, SkipLayersPtr, SkipLayersLen, + "Failed when accessing the SkipLayers memory."sv) + if (!parameterCheck(Env, Width, Height, SessionId)) { + return static_cast(ErrNo::InvalidArgument); + } + SESSION_CHECK(SDCtx, SessionId, "Session ID is invalid."sv, + ErrNo::InvalidArgument) + std::string Prompt(PromptSpan.begin(), PromptSpan.end()); + std::string NegativePrompt(NegativePromptSpan.begin(), + NegativePromptSpan.end()); + std::string InputIdImagesDir(InputIdImagesDirSpan.begin(), + InputIdImagesDirSpan.end()); + std::string OutputPath(OutputPathSpan.begin(), OutputPathSpan.end()); + // Read input image + uint8_t *InputImageBuffer = nullptr; + int Channel = 0; + int ImageWidth = 0; + int ImageHeight = 0; + std::string ImagePath(ImageSpan.begin(), ImageSpan.end()); + if (ImagePath.substr(0, 5) == "path:"sv) { + InputImageBuffer = stbi_load(ImagePath.substr(5).data(), &ImageWidth, + &ImageHeight, &Channel, 3); + if (InputImageBuffer == nullptr) { + spdlog::error( + "[WasmEdge-StableDiffusion] Load image from input image failed."sv); + return static_cast(ErrNo::InvalidArgument); + } + if (Channel < 3) { + spdlog::error( + "[WasmEdge-StableDiffusion] The number of channels for the input image must be >= 3."sv); + free(InputImageBuffer); + return static_cast(ErrNo::InvalidArgument); + } + if (ImageWidth <= 0) { + spdlog::error( + "[WasmEdge-StableDiffusion] The width of image must be greater than 0."sv); + free(InputImageBuffer); + return static_cast(ErrNo::InvalidArgument); + } + if (ImageHeight <= 0) { + spdlog::error( + "[WasmEdge-StableDiffusion] The height of image must be greater than 0."sv); + free(InputImageBuffer); + return static_cast(ErrNo::InvalidArgument); + } + // Resize image when its size does not match the width and height. + if (Height != static_cast(ImageHeight) || + Width != static_cast(ImageWidth)) { + int ResizedHeight = Height; + int ResizedWidth = Width; + uint8_t *ResizedImageBuffer = + (uint8_t *)malloc(ResizedHeight * ResizedWidth * 3); + if (ResizedImageBuffer == nullptr) { + spdlog::error( + "[WasmEdge-StableDiffusion] Failed to allocate memory for resize input image."sv); + free(InputImageBuffer); + return static_cast(ErrNo::InvalidArgument); + } + stbir_resize(InputImageBuffer, ImageWidth, ImageHeight, 0, + ResizedImageBuffer, ResizedWidth, ResizedHeight, 0, + STBIR_TYPE_UINT8, 3, STBIR_ALPHA_CHANNEL_NONE, 0, + STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_FILTER_BOX, + STBIR_FILTER_BOX, STBIR_COLORSPACE_SRGB, nullptr); + free(InputImageBuffer); + InputImageBuffer = ResizedImageBuffer; + } + } else { + InputImageBuffer = + stbi_load_from_memory(ImageSpan.data(), ImageSpan.size(), &ImageWidth, + &ImageHeight, &Channel, 3); + } + sd_image_t InputImage = {Width, Height, 3, InputImageBuffer}; + // Read control image + sd_image_t *ControlImage = nullptr; + if (ControlImageLen != 0) { + MEM_SPAN_CHECK(ControlImageSpan, MemInst, uint8_t, ControlImagePtr, + ControlImageLen, + "Failed when accessing the control image memory."sv) + ControlImage = + readControlImage(ControlImageSpan, Width, Height, CannyPreprocess); + } + // Read mask image + sd_image_t MaskImage = readMaskImage(MaskImageSpan, Width, Height); + // Generate images + sd_image_t *Results = nullptr; + spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); + Results = + img2img(SDCtx, InputImage, MaskImage, Prompt.data(), + NegativePrompt.data(), ClipSkip, CfgScale, Guidance, Width, + Height, sample_method_t(SampleMethod), SampleSteps, Strength, + Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, + NormalizeInput, InputIdImagesDir.data(), SkipLayersSpan.data(), + SkipLayersSpan.size(), SlgScale, SkipLayerStart, SkipLayerEnd); + free(ControlImage); + free(InputImageBuffer); + if (Results == nullptr) { + spdlog::error("[WasmEdge-StableDiffusion] Generate failed."sv); + Env.freeContext(SessionId); + return static_cast(ErrNo::RuntimeError); + } + // Upscale image + if (UpscaleModelPathLen != 0) { + MEM_SPAN_CHECK(UpscaleModelSpan, MemInst, char, UpscaleModelPathPtr, + UpscaleModelPathLen, + "Failed when accessing the Upscaler Image memory."sv) + std::string UpscaleModelPath(UpscaleModelSpan.begin(), + UpscaleModelSpan.end()); + upscalerModel(UpscaleModelPath.data(), UpscaleRepeats, + Env.getNThreads(SessionId), BatchCount, Results); + } + // Save results + if (!saveResults(Results, BatchCount, OutputPath, OutputPathLen, BytesWritten, + OutBufferMaxSize, OutputBufferSpan.data())) { + return static_cast(ErrNo::RuntimeError); + } + return static_cast(ErrNo::Success); +} + +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h new file mode 100644 index 00000000..aa9285bc --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "runtime/callingframe.h" +#include "sd_base.h" +#include "stable-diffusion.h" + +namespace WasmEdge { +namespace Host { +namespace StableDiffusion { + +class SDCreateContext : public StableDiffusion::Func { +public: + SDCreateContext(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} + Expect + body(const Runtime::CallingFrame &Frame, uint32_t ModelPathPtr, + uint32_t ModelPathLen, uint32_t clipLPathPtr, uint32_t clipLPathLen, + uint32_t clipGPathPtr, uint32_t clipGPathLen, uint32_t t5xxlPathPtr, + uint32_t t5xxlPathLen, uint32_t diffusionModelPathPtr, + uint32_t diffusionModelPathLen, uint32_t VaePathPtr, uint32_t VaePathLen, + uint32_t TaesdPathPtr, uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, + uint32_t ControlNetPathLen, uint32_t LoraModelDirPtr, + uint32_t LoraModelDirLen, uint32_t EmbedDirPtr, uint32_t EmbedDirLen, + uint32_t IdEmbedDirPtr, uint32_t IdEmbedDirLen, uint32_t VaeDecodeOnly, + uint32_t VaeTiling, int32_t NThreads, uint32_t Wtype, uint32_t RngType, + uint32_t Schedule, uint32_t ClipOnCpu, uint32_t ControlNetCpu, + uint32_t VaeOnCpu, uint32_t DiffusionFlashAttn, uint32_t SessiontIdPtr); +}; + +class SDImageToImage : public StableDiffusion::Func { +public: + SDImageToImage(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} + Expect + body(const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, + uint32_t MaskImagePtr, uint32_t MaskImageLen, uint32_t SessionId, + float Guidance, uint32_t Width, uint32_t Height, + uint32_t ControlImagePtr, uint32_t ControlImageLen, uint32_t PromptPtr, + uint32_t PromptLen, uint32_t NegativePromptPtr, + uint32_t NegativePromptLen, int32_t ClipSkip, float CfgScale, + uint32_t SampleMethod, uint32_t SampleSteps, float Strength, + uint32_t Seed, uint32_t BatchCount, float ControlStrength, + float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, + uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, + uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, + uint32_t UpscaleRepeats, uint32_t SkipLayersPtr, uint32_t SkipLayersLen, + float SlgScale, float SkipLayerStart, float SkipLayerEnd, + uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); +}; + +class SDTextToImage : public StableDiffusion::Func { +public: + SDTextToImage(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} + Expect + body(const Runtime::CallingFrame &Frame, uint32_t PromptPtr, + uint32_t PromptLen, uint32_t SessionId, uint32_t ControlImagePtr, + uint32_t ControlImageLen, uint32_t NegativePromptPtr, + uint32_t NegativePromptLen, float Guidance, uint32_t Width, + uint32_t Height, int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, + uint32_t SampleSteps, uint32_t Seed, uint32_t BatchCount, + float ControlStrength, float StyleRatio, uint32_t NormalizeInput, + uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, + uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, + uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, + uint32_t SkipLayersPtr, uint32_t SkipLayersLen, float SlgScale, + float SkipLayerStart, float SkipLayerEnd, uint32_t OutputPathPtr, + uint32_t OutputPathLen, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr); +}; + +class SDConvert : public StableDiffusion::Func { +public: + SDConvert(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t ModelPathPtr, uint32_t ModelPathLen, + uint32_t VaeModelPathPtr, uint32_t VaeModelPathLen, + uint32_t OutputPathPtr, uint32_t OutputPathLen, + uint32_t WType); +}; + +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_module.cpp b/plugins/wasmedge_stablediffusion/sd_module.cpp new file mode 100644 index 00000000..a568c4ab --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sd_module.cpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "sd_module.h" +#include "sd_func.h" + +namespace WasmEdge { +namespace Host { + +SDModule::SDModule() : ModuleInstance("wasmedge_stablediffusion") { + addHostFunc("create_context", + std::make_unique(Env)); + addHostFunc("image_to_image", + std::make_unique(Env)); + addHostFunc("text_to_image", + std::make_unique(Env)); + addHostFunc("convert", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_module.h b/plugins/wasmedge_stablediffusion/sd_module.h new file mode 100644 index 00000000..bfc7ba72 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sd_module.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "sd_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class SDModule : public Runtime::Instance::ModuleInstance { +public: + SDModule(); + StableDiffusion::SDEnviornment &getEnv() { return Env; } + +private: + StableDiffusion::SDEnviornment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/CMakeLists.txt b/plugins/wasmedge_tensorflow/CMakeLists.txt new file mode 100644 index 00000000..ccfe25ed --- /dev/null +++ b/plugins/wasmedge_tensorflow/CMakeLists.txt @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_library(wasmedgePluginWasmEdgeTensorflow + SHARED + tensorflow_env.cpp + tensorflow_func.cpp + tensorflow_module.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeTensorflow + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeTensorflow + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeTensorflow + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeTensorflow + PRIVATE + wasmedge_shared + ) +endif() + +include(WASINNDeps) +wasmedge_setup_tf_target(wasmedgePluginWasmEdgeTensorflow) + +install( + TARGETS wasmedgePluginWasmEdgeTensorflow + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_tensorflow/tensorflow_base.h b/plugins/wasmedge_tensorflow/tensorflow_base.h new file mode 100644 index 00000000..fb17fec5 --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "tensorflow_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflow { + +template class Func : public Runtime::HostFunction { +public: + Func(TFEnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + TFEnv &Env; +}; + +} // namespace WasmEdgeTensorflow +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_env.cpp b/plugins/wasmedge_tensorflow/tensorflow_env.cpp new file mode 100644 index 00000000..98312b14 --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_env.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "tensorflow_env.h" +#include "tensorflow_module.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeTensorflowModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_tensorflow", + .Description = "Tensorflow plug-in for WasmEdge.", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 13, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_tensorflow", + .Description = + "This module contains WasmEdge-Tensorflow host functions.", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_env.h b/plugins/wasmedge_tensorflow/tensorflow_env.h new file mode 100644 index 00000000..5fd4ef3c --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_env.h @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include "tensorflow/c/c_api.h" + +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflow { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. +}; + +struct TensorList { + void reset() noexcept { + for (uint32_t I = 0; I < DataList.size(); ++I) { + if (DataList[I]) { + TF_DeleteTensor(DataList[I]); + } + } + NameMap.clear(); + OperList.clear(); + DataList.clear(); + } + + std::unordered_map NameMap; + std::vector OperList; + std::vector DataList; +}; + +struct Context { + Context() noexcept { Stat = TF_NewStatus(); } + ~Context() noexcept { + reset(); + TF_DeleteStatus(Stat); + } + + void clearInputs() noexcept { Inputs.reset(); } + + void clearOutputs() noexcept { Outputs.reset(); } + + void reset() noexcept { + if (GraphOpts) { + TF_DeleteImportGraphDefOptions(GraphOpts); + GraphOpts = nullptr; + } + if (Buffer) { + TF_DeleteBuffer(Buffer); + Buffer = nullptr; + } + if (Graph) { + TF_DeleteGraph(Graph); + Graph = nullptr; + } + if (SessionOpts) { + TF_DeleteSessionOptions(SessionOpts); + SessionOpts = nullptr; + } + if (Session) { + TF_CloseSession(Session, Stat); + TF_DeleteSession(Session, Stat); + Session = nullptr; + } + clearInputs(); + clearOutputs(); + } + + TF_Status *Stat; + TF_ImportGraphDefOptions *GraphOpts = nullptr; + TF_Buffer *Buffer = nullptr; + TF_Graph *Graph = nullptr; + TF_SessionOptions *SessionOpts = nullptr; + TF_Session *Session = nullptr; + struct TensorList Inputs; + struct TensorList Outputs; +}; + +struct TFEnv { + TFEnv() noexcept { TFContext.reserve(16U); } + + Context *getContext(const uint32_t ID) noexcept { + auto It = RecycledIdx.find(ID); + if (ID < TFContext.size() && It == RecycledIdx.end()) { + return &TFContext[ID]; + } + return nullptr; + } + uint32_t newContext() noexcept { + uint32_t NewIdx = TFContext.size(); + if (RecycledIdx.empty()) { + TFContext.emplace_back(); + } else { + NewIdx = *RecycledIdx.begin(); + RecycledIdx.erase(NewIdx); + } + return NewIdx; + } + void deleteContext(const uint32_t ID) noexcept { + if (ID < TFContext.size()) { + TFContext[ID].reset(); + RecycledIdx.insert(ID); + } + } + +private: + std::unordered_set RecycledIdx; + std::vector TFContext; +}; + +} // namespace WasmEdgeTensorflow +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_func.cpp b/plugins/wasmedge_tensorflow/tensorflow_func.cpp new file mode 100644 index 00000000..2b8e6cd3 --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_func.cpp @@ -0,0 +1,440 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "tensorflow_func.h" + +#include "common/span.h" +#include "common/spdlog.h" + +#include "tensorflow/c/c_api.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflow { + +using namespace std::literals::string_view_literals; + +namespace { + +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow] Memory instance not found."sv); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define SESSION_CHECK(Out, SessionID, Message, ErrNo) \ + auto *Out = Env.getContext(SessionID); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow] "sv Message); \ + return static_cast(ErrNo); \ + } + +#define MEM_SPAN_CHECK(OutSpan, MemInst, Type, BufPtr, BufLen, Message) \ + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); \ + if (unlikely(OutSpan.size() != BufLen)) { \ + spdlog::error("[WasmEdge-Tensorflow] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_SV_CHECK(OutSV, MemInst, BufPtr, BufLen, Message) \ + auto OutSV = MemInst->getStringView(BufPtr, BufLen); \ + if (unlikely(OutSV.size() != BufLen)) { \ + spdlog::error("[WasmEdge-Tensorflow] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_PTR_CHECK(OutPtr, MemInst, Type, Offset, Message) \ + Type *OutPtr = MemInst->getPointer(Offset); \ + if (unlikely(OutPtr == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +std::pair parseIndex(std::string_view Name) { + // Check if there's index in the string key. + size_t Pos = Name.find(":"); + int Idx = 0; + std::string NameStr; + if (Pos != std::string::npos) { + Idx = std::strtol(Name.data() + Pos + 1, nullptr, 10); + NameStr = Name.substr(0, Pos); + } else { + NameStr = Name; + } + return std::make_pair(NameStr, Idx); +} + +} // namespace + +Expect CreateSession::body(const Runtime::CallingFrame &Frame, + uint32_t ModBufPtr, uint32_t ModBufLen, + uint32_t SessionIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input model buffer. + MEM_SPAN_CHECK(ModBufSpan, MemInst, char, ModBufPtr, ModBufLen, + "Failed when accessing the input model buffer memory."sv) + + // Check the return value: SessionIdPtr should be valid. + MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessionIdPtr, + "Failed when accessing the return SessionID memory."sv) + + // Create context and import graph. + uint32_t NewID = Env.newContext(); + SESSION_CHECK(Cxt, NewID, "Failed when allocating resources."sv, + ErrNo::MissingMemory) + + Cxt->Graph = TF_NewGraph(); + Cxt->Buffer = TF_NewBufferFromString(ModBufSpan.data(), ModBufLen); + Cxt->GraphOpts = TF_NewImportGraphDefOptions(); + TF_GraphImportGraphDef(Cxt->Graph, Cxt->Buffer, Cxt->GraphOpts, Cxt->Stat); + if (unlikely(TF_GetCode(Cxt->Stat) != TF_OK)) { + spdlog::error("[WasmEdge-Tensorflow] Cannot import graph from buffer: {}"sv, + TF_Message(Cxt->Stat)); + Env.deleteContext(NewID); + return static_cast(ErrNo::InvalidArgument); + } + + // Create session. + Cxt->SessionOpts = TF_NewSessionOptions(); + Cxt->Session = TF_NewSession(Cxt->Graph, Cxt->SessionOpts, Cxt->Stat); + if (unlikely(TF_GetCode(Cxt->Stat) != TF_OK)) { + spdlog::error("[WasmEdge-Tensorflow] Unable to create session: {}"sv, + TF_Message(Cxt->Stat)); + Env.deleteContext(NewID); + return static_cast(ErrNo::InvalidArgument); + } + + *SessionId = NewID; + return static_cast(ErrNo::Success); +} + +Expect CreateSessionSavedModel::body( + const Runtime::CallingFrame &Frame, uint32_t PathPtr, uint32_t PathLen, + uint32_t TagsBufPtr, uint32_t TagsBufLen, uint32_t SessionIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the model path buffer. + MEM_SV_CHECK(PathSV, MemInst, PathPtr, PathLen, + "Failed when accessing the model path buffer memory."sv) + + // Check the tags buffer. + struct MetaGraphDefTag { + uint32_t Ptr; + uint32_t Len; + }; + MEM_SPAN_CHECK(TagSpan, MemInst, MetaGraphDefTag, TagsBufPtr, TagsBufLen, + "Failed when accessing the tags memory."sv) + + // Check the elements of tags. + std::vector Tags; + std::vector TagsArgv; + Tags.reserve(TagsBufLen); + TagsArgv.reserve(TagsBufLen); + for (size_t I = 0; I < TagSpan.size(); ++I) { + // Use std::string to copy the tag name here and avoid relying on + // null-termination of the tag strings. + const auto &Tag = TagSpan[I]; + MEM_SV_CHECK(TagNameSV, MemInst, Tag.Ptr, Tag.Len, + "Failed when accessing the tag name memory."sv) + Tags.emplace_back(TagNameSV); + TagsArgv.emplace_back(Tags.back().c_str()); + } + + // Check the return value: SessionIdPtr should be valid. + MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessionIdPtr, + "Failed when accessing the return SessionID memory."sv) + + // Create context and import graph. + uint32_t NewID = Env.newContext(); + SESSION_CHECK(Cxt, NewID, "Failed when allocating resources."sv, + ErrNo::MissingMemory) + + // Create session. + Cxt->Graph = TF_NewGraph(); + Cxt->GraphOpts = TF_NewImportGraphDefOptions(); + Cxt->SessionOpts = TF_NewSessionOptions(); + Cxt->Session = TF_LoadSessionFromSavedModel( + Cxt->SessionOpts, nullptr, std::string(PathSV).c_str(), TagsArgv.data(), + TagsArgv.size(), Cxt->Graph, nullptr, Cxt->Stat); + if (unlikely(TF_GetCode(Cxt->Stat) != TF_OK)) { + spdlog::error("[WasmEdge-Tensorflow] Unable to create session: {}"sv, + TF_Message(Cxt->Stat)); + Env.deleteContext(NewID); + return static_cast(ErrNo::InvalidArgument); + } + + *SessionId = NewID; + return static_cast(ErrNo::Success); +} + +Expect DeleteSession::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + Env.deleteContext(SessionId); + return static_cast(ErrNo::Success); +} + +Expect RunSession::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Delete old output tensors + for (auto T : Cxt->Outputs.DataList) { + if (T) { + TF_DeleteTensor(T); + } + } + + // Run session + TF_SessionRun(Cxt->Session, + // RunOptions + nullptr, + // Input tensors + Cxt->Inputs.OperList.data(), Cxt->Inputs.DataList.data(), + Cxt->Inputs.DataList.size(), + // Output tensors + Cxt->Outputs.OperList.data(), Cxt->Outputs.DataList.data(), + Cxt->Outputs.DataList.size(), + // Target operations + nullptr, 0, + // RunMetadata + nullptr, + // Output status + Cxt->Stat); + + if (unlikely(TF_GetCode(Cxt->Stat) != TF_OK)) { + spdlog::error("[WasmEdge-Tensorflow] Run session failed: {}"sv, + TF_Message(Cxt->Stat)); + return static_cast(ErrNo::Busy); + } + return static_cast(ErrNo::Success); +} + +Expect GetOutputTensor::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t NamePtr, + uint32_t NameLen, uint32_t TensorIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the input tensor operation name buffer. + MEM_SV_CHECK(NameSV, MemInst, NamePtr, NameLen, + "Failed when accessing the output name buffer memory."sv) + + // Check the return value: TensorIdPtr should be valid. + MEM_PTR_CHECK(TensorId, MemInst, uint32_t, TensorIdPtr, + "Failed when accessing the return TensorID memory."sv) + + // Find the output tensor ID. + auto It = Cxt->Outputs.NameMap.find(std::string(NameSV)); + if (unlikely(It == Cxt->Outputs.NameMap.end())) { + spdlog::error("[WasmEdge-Tensorflow] Output tensor {} not found."sv, + NameSV); + return static_cast(ErrNo::InvalidArgument); + } + *TensorId = It->second; + return static_cast(ErrNo::Success); +} + +Expect GetTensorLen::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t TensorId, + uint32_t LenPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the return value: LenPtr should be valid. + MEM_PTR_CHECK(Len, MemInst, uint32_t, LenPtr, + "Failed when accessing the return Length memory."sv) + + // Get output tensor from ID. + if (unlikely(TensorId >= Cxt->Outputs.DataList.size())) { + spdlog::error("[WasmEdge-Tensorflow] Invalid tensor ID."sv); + return static_cast(ErrNo::InvalidArgument); + } + + // Return tensor data length. + auto *Tensor = Cxt->Outputs.DataList[TensorId]; + if (likely(Tensor != nullptr)) { + *Len = TF_TensorByteSize(Tensor); + } else { + *Len = 0U; + } + return static_cast(ErrNo::Success); +} + +Expect GetTensorData::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t TensorId, + uint32_t BufPtr, uint32_t BufLen, + uint32_t WrittenBytesPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the output tensor buffer. + MEM_SPAN_CHECK( + BufSpan, MemInst, char, BufPtr, BufLen, + "Failed when accessing the output tensor write buffer memory."sv) + + // Check the return value: WrittenBytesPtr should be valid. + MEM_PTR_CHECK(WrittenBytes, MemInst, uint32_t, WrittenBytesPtr, + "Failed when accessing the return WrittenBytes memory."sv) + + // Get output tensor from ID. + if (unlikely(TensorId >= Cxt->Outputs.DataList.size())) { + spdlog::error("[WasmEdge-Tensorflow] Invalid tensor ID."sv); + return static_cast(ErrNo::InvalidArgument); + } + + // Copy tensor data to buffer. + auto *Tensor = Cxt->Outputs.DataList[TensorId]; + size_t RealSize = TF_TensorByteSize(Tensor); + *WrittenBytes = 0U; + if (Tensor != nullptr && RealSize > 0 && BufLen > 0) { + *WrittenBytes = std::min(static_cast(RealSize), BufLen); + char *Data = static_cast(TF_TensorData(Tensor)); + std::copy_n(Data, *WrittenBytes, BufSpan.data()); + } + return static_cast(ErrNo::Success); +} + +Expect AppendInput::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t NamePtr, + uint32_t NameLen, uint32_t DimPtr, + uint32_t DimCnt, uint32_t DataType, + uint32_t TensorBufPtr, + uint32_t TensorBufLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the input tensor buffer. + MEM_SPAN_CHECK(TensorBufSpan, MemInst, uint8_t, TensorBufPtr, TensorBufLen, + "Failed when accessing the input tensor buffer memory."sv) + + // Check the input tensor dimension buffer. + MEM_SPAN_CHECK(DimBufSpan, MemInst, int64_t, DimPtr, DimCnt, + "Failed when accessing the input dimension buffer memory."sv) + + // Check the input tensor operation name buffer. + MEM_SV_CHECK(NameSV, MemInst, NamePtr, NameLen, + "Failed when accessing the input name buffer memory."sv) + + // Check the input operation. + auto OperKeyPair = parseIndex(NameSV); + TF_Operation *Operation = + TF_GraphOperationByName(Cxt->Graph, OperKeyPair.first.c_str()); + if (unlikely(Operation == nullptr)) { + spdlog::error("[WasmEdge-Tensorflow] Input operation {} not found."sv, + NameSV); + return static_cast(ErrNo::InvalidArgument); + } + + // Check if the input tensor by name exists. + uint32_t TensorId = Cxt->Inputs.DataList.size(); + auto It = Cxt->Inputs.NameMap.find(std::string(NameSV)); + if (It != Cxt->Inputs.NameMap.end()) { + TensorId = It->second; + } + + // Create the tensor and copy data from buffer. + TF_Tensor *Tensor = nullptr; + if (DimCnt > 0) { + Tensor = TF_AllocateTensor(static_cast(DataType), + DimBufSpan.data(), DimCnt, TensorBufLen); + } else { + Tensor = TF_AllocateTensor(static_cast(DataType), nullptr, 0, + TensorBufLen); + } + if (unlikely(Tensor == nullptr)) { + spdlog::error("[WasmEdge-Tensorflow] Allocate input tensor failed."sv); + return static_cast(ErrNo::Busy); + } + std::copy_n(TensorBufSpan.begin(), TensorBufLen, + static_cast(TF_TensorData(Tensor))); + + // If the old input tensor exists, delete the old one. + if (It != Cxt->Inputs.NameMap.end()) { + TF_DeleteTensor(Cxt->Inputs.DataList[TensorId]); + Cxt->Inputs.DataList[TensorId] = Tensor; + } else { + Cxt->Inputs.OperList.emplace_back(TF_Output{Operation, OperKeyPair.second}); + Cxt->Inputs.DataList.push_back(Tensor); + Cxt->Inputs.NameMap.insert({std::string(NameSV), TensorId}); + } + return static_cast(ErrNo::Success); +} + +Expect AppendOutput::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t NamePtr, + uint32_t NameLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the output tensor operation name buffer. + MEM_SV_CHECK(NameSV, MemInst, NamePtr, NameLen, + "Failed when accessing the output name buffer memory."sv) + + // Check the output operation. + auto OperKeyPair = parseIndex(NameSV); + TF_Operation *Operation = + TF_GraphOperationByName(Cxt->Graph, OperKeyPair.first.c_str()); + if (unlikely(Operation == nullptr)) { + spdlog::error("[WasmEdge-Tensorflow] Output operation {} not found."sv, + NameSV); + return static_cast(ErrNo::InvalidArgument); + } + + // Store names and operations if the output tensor key not exists. + auto It = Cxt->Outputs.NameMap.find(std::string(NameSV)); + if (It == Cxt->Outputs.NameMap.end()) { + uint32_t TensorId = Cxt->Outputs.DataList.size(); + Cxt->Outputs.OperList.emplace_back( + TF_Output{Operation, OperKeyPair.second}); + Cxt->Outputs.DataList.push_back(nullptr); + Cxt->Outputs.NameMap.insert({std::string(NameSV), TensorId}); + } + return static_cast(ErrNo::Success); +} + +Expect ClearInput::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Clear the inputs. + Cxt->clearInputs(); + return static_cast(ErrNo::Success); +} + +Expect ClearOutput::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Clear the outputs. + Cxt->clearOutputs(); + return static_cast(ErrNo::Success); +} + +} // namespace WasmEdgeTensorflow +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_func.h b/plugins/wasmedge_tensorflow/tensorflow_func.h new file mode 100644 index 00000000..54b5e76a --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_func.h @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "tensorflow_base.h" + +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflow { + +class CreateSession : public Func { +public: + CreateSession(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ModBufPtr, + uint32_t ModBufLen, uint32_t SessionIdPtr); +}; + +class CreateSessionSavedModel : public Func { +public: + CreateSessionSavedModel(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t PathPtr, + uint32_t PathLen, uint32_t TagsBufPtr, + uint32_t TagsBufLen, uint32_t SessionIdPtr); +}; + +class DeleteSession : public Func { +public: + DeleteSession(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +class RunSession : public Func { +public: + RunSession(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +class GetOutputTensor : public Func { +public: + GetOutputTensor(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t NamePtr, uint32_t NameLen, + uint32_t TensorIdPtr); +}; + +class GetTensorLen : public Func { +public: + GetTensorLen(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t TensorId, uint32_t LenPtr); +}; + +class GetTensorData : public Func { +public: + GetTensorData(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t TensorId, uint32_t BufPtr, uint32_t BufLen, + uint32_t WrittenBytesPtr); +}; + +class AppendInput : public Func { +public: + AppendInput(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t NamePtr, uint32_t NameLen, uint32_t DimPtr, + uint32_t DimCnt, uint32_t DataType, + uint32_t TensorBufPtr, uint32_t TensorBufLen); +}; + +class AppendOutput : public Func { +public: + AppendOutput(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t NamePtr, uint32_t NameLen); +}; + +class ClearInput : public Func { +public: + ClearInput(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +class ClearOutput : public Func { +public: + ClearOutput(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +} // namespace WasmEdgeTensorflow +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_module.cpp b/plugins/wasmedge_tensorflow/tensorflow_module.cpp new file mode 100644 index 00000000..f0703e45 --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_module.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "tensorflow_module.h" +#include "tensorflow_func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasmEdgeTensorflowModule::WasmEdgeTensorflowModule() + : Runtime::Instance::ModuleInstance("wasmedge_tensorflow") { + addHostFunc("create_session", + std::make_unique(Env)); + addHostFunc( + "create_session_saved_model", + std::make_unique(Env)); + addHostFunc("delete_session", + std::make_unique(Env)); + addHostFunc("run_session", + std::make_unique(Env)); + addHostFunc("get_output_tensor", + std::make_unique(Env)); + addHostFunc("get_tensor_len", + std::make_unique(Env)); + addHostFunc("get_tensor_data", + std::make_unique(Env)); + addHostFunc("append_input", + std::make_unique(Env)); + addHostFunc("append_output", + std::make_unique(Env)); + addHostFunc("clear_input", + std::make_unique(Env)); + addHostFunc("clear_output", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_module.h b/plugins/wasmedge_tensorflow/tensorflow_module.h new file mode 100644 index 00000000..dfb96f9d --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_module.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "tensorflow_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeTensorflowModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeTensorflowModule(); + ~WasmEdgeTensorflowModule() = default; + + WasmEdgeTensorflow::TFEnv &getEnv() { return Env; } + +private: + WasmEdgeTensorflow::TFEnv Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/CMakeLists.txt b/plugins/wasmedge_tensorflowlite/CMakeLists.txt new file mode 100644 index 00000000..f8ee177d --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/CMakeLists.txt @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_library(wasmedgePluginWasmEdgeTensorflowLite + SHARED + tensorflowlite_env.cpp + tensorflowlite_func.cpp + tensorflowlite_module.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeTensorflowLite + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeTensorflowLite + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeTensorflowLite + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeTensorflowLite + PRIVATE + wasmedge_shared + ) +endif() + +include(WASINNDeps) +wasmedge_setup_tflite_target(wasmedgePluginWasmEdgeTensorflowLite) + +install( + TARGETS wasmedgePluginWasmEdgeTensorflowLite + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_base.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_base.h new file mode 100644 index 00000000..075a46f7 --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "tensorflowlite_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflowLite { + +template class Func : public Runtime::HostFunction { +public: + Func(TFLiteEnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + TFLiteEnv &Env; +}; + +} // namespace WasmEdgeTensorflowLite +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp new file mode 100644 index 00000000..12161a6d --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "tensorflowlite_env.h" +#include "tensorflowlite_module.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeTensorflowLiteModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_tensorflowlite", + .Description = "Tensorflow-Lite plug-in for WasmEdge.", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 13, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_tensorflowlite", + .Description = "This module contains WasmEdge-TensorflowLite " + "host functions.", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h new file mode 100644 index 00000000..02da4069 --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include "tensorflow/lite/c/c_api.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflowLite { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. +}; + +struct Context { + Context() = default; + ~Context() { reset(); } + void reset() noexcept { + if (Interp) { + TfLiteInterpreterDelete(Interp); + } + Interp = nullptr; + } + TfLiteInterpreter *Interp = nullptr; +}; + +struct TFLiteEnv { + TFLiteEnv() noexcept { TFLiteContext.reserve(16U); } + + Context *getContext(const uint32_t ID) noexcept { + auto It = RecycledIdx.find(ID); + if (ID < TFLiteContext.size() && It == RecycledIdx.end()) { + return &TFLiteContext[ID]; + } + return nullptr; + } + uint32_t newContext() noexcept { + uint32_t NewIdx = TFLiteContext.size(); + if (RecycledIdx.empty()) { + TFLiteContext.emplace_back(); + } else { + NewIdx = *RecycledIdx.begin(); + RecycledIdx.erase(NewIdx); + } + return NewIdx; + } + void deleteContext(const uint32_t ID) noexcept { + if (ID < TFLiteContext.size()) { + TFLiteContext[ID].reset(); + RecycledIdx.insert(ID); + } + } + +private: + std::unordered_set RecycledIdx; + std::vector TFLiteContext; +}; + +} // namespace WasmEdgeTensorflowLite +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp new file mode 100644 index 00000000..48782592 --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "tensorflowlite_func.h" + +#include "common/span.h" +#include "common/spdlog.h" + +#include "tensorflow/lite/c/c_api.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflowLite { + +namespace { + +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow-Lite] Memory instance not found."sv); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define SESSION_CHECK(Out, SessionID, Message, ErrNo) \ + auto *Out = Env.getContext(SessionID); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow-Lite] "sv Message); \ + return static_cast(ErrNo); \ + } + +#define MEM_SPAN_CHECK(OutSpan, MemInst, Type, BufPtr, BufLen, Message) \ + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); \ + if (unlikely(OutSpan.size() != BufLen)) { \ + spdlog::error("[WasmEdge-Tensorflow-Lite] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_SV_CHECK(OutSV, MemInst, BufPtr, BufLen, Message) \ + auto OutSV = MemInst->getStringView(BufPtr, BufLen); \ + if (unlikely(OutSV.size() != BufLen)) { \ + spdlog::error("[WasmEdge-Tensorflow-Lite] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_PTR_CHECK(OutPtr, MemInst, Type, Offset, Message) \ + Type *OutPtr = MemInst->getPointer(Offset); \ + if (unlikely(OutPtr == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow-Lite] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +} // namespace + +Expect CreateSession::body(const Runtime::CallingFrame &Frame, + uint32_t ModBufPtr, uint32_t ModBufLen, + uint32_t SessionIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input model buffer. + MEM_SPAN_CHECK(ModBufSpan, MemInst, char, ModBufPtr, ModBufLen, + "Failed when accessing the input model buffer memory."sv) + + // Check the return value: SessionIdPtr should be valid. + MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessionIdPtr, + "Failed when accessing the return SessionID memory."sv) + + // Create context and import graph. + uint32_t NewID = Env.newContext(); + SESSION_CHECK(Cxt, NewID, "Failed when allocating resources."sv, + ErrNo::MissingMemory) + + auto *Model = TfLiteModelCreate(ModBufSpan.data(), ModBufLen); + if (unlikely(Model == nullptr)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Cannot import TFLite model."sv); + Env.deleteContext(NewID); + return static_cast(ErrNo::InvalidArgument); + } + auto *Ops = TfLiteInterpreterOptionsCreate(); + if (unlikely(Ops == nullptr)) { + spdlog::error( + "[WasmEdge-Tensorflow-Lite] Cannot create TFLite interpreter options."sv); + Env.deleteContext(NewID); + TfLiteModelDelete(Model); + return static_cast(ErrNo::Busy); + } + TfLiteInterpreterOptionsSetNumThreads(Ops, 2); + Cxt->Interp = TfLiteInterpreterCreate(Model, Ops); + TfLiteInterpreterOptionsDelete(Ops); + TfLiteModelDelete(Model); + if (unlikely(Cxt->Interp == nullptr)) { + spdlog::error( + "[WasmEdge-Tensorflow-Lite] Cannot create TFLite interpreter."sv); + Env.deleteContext(NewID); + return static_cast(ErrNo::Busy); + } + TfLiteStatus Status = TfLiteInterpreterAllocateTensors(Cxt->Interp); + if (unlikely(Status != TfLiteStatus::kTfLiteOk)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Cannot allocate tensors."sv); + Env.deleteContext(NewID); + return static_cast(ErrNo::Busy); + } + + *SessionId = NewID; + return static_cast(ErrNo::Success); +} + +Expect DeleteSession::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + Env.deleteContext(SessionId); + return static_cast(ErrNo::Success); +} + +Expect RunSession::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Run session + TfLiteStatus Stat = TfLiteInterpreterInvoke(Cxt->Interp); + if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Invocation failed."sv); + return static_cast(ErrNo::Busy); + } + return static_cast(ErrNo::Success); +} + +Expect GetOutputTensor::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t NamePtr, + uint32_t NameLen, uint32_t TensorIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the input tensor operation name buffer. + MEM_SV_CHECK(NameSV, MemInst, NamePtr, NameLen, + "Failed when accessing the output name buffer memory."sv) + + // Check the return value: TensorIdPtr should be valid. + MEM_PTR_CHECK(TensorId, MemInst, uint32_t, TensorIdPtr, + "Failed when accessing the return TensorID memory."sv) + + // Find the output tensor. + bool IsFound = false; + uint32_t OutCnt = TfLiteInterpreterGetOutputTensorCount(Cxt->Interp); + for (uint32_t I = 0; I < OutCnt; ++I) { + const TfLiteTensor *T = TfLiteInterpreterGetOutputTensor(Cxt->Interp, I); + if (NameSV == std::string(TfLiteTensorName(T))) { + *TensorId = I; + IsFound = true; + break; + } + } + if (unlikely(!IsFound)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Output tensor {} not found."sv, + NameSV); + return static_cast(ErrNo::InvalidArgument); + } + return static_cast(ErrNo::Success); +} + +Expect GetTensorLen::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t TensorId, + uint32_t LenPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the return value: LenPtr should be valid. + MEM_PTR_CHECK(Len, MemInst, uint32_t, LenPtr, + "Failed when accessing the return Length memory."sv) + + // Get output tensor from ID. + uint32_t OutCnt = TfLiteInterpreterGetOutputTensorCount(Cxt->Interp); + if (unlikely(TensorId >= OutCnt)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Invalid tensor ID."sv); + return static_cast(ErrNo::InvalidArgument); + } + + // Return tensor data length. + const TfLiteTensor *Tensor = + TfLiteInterpreterGetOutputTensor(Cxt->Interp, TensorId); + if (likely(Tensor != nullptr)) { + *Len = TfLiteTensorByteSize(Tensor); + } else { + *Len = 0U; + } + return static_cast(ErrNo::Success); +} + +Expect GetTensorData::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t TensorId, + uint32_t BufPtr, uint32_t BufLen, + uint32_t WrittenBytesPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the output tensor buffer. + MEM_SPAN_CHECK( + BufSpan, MemInst, char, BufPtr, BufLen, + "Failed when accessing the output tensor write buffer memory."sv) + + // Check the return value: WrittenBytesPtr should be valid. + MEM_PTR_CHECK(WrittenBytes, MemInst, uint32_t, WrittenBytesPtr, + "Failed when accessing the return WrittenBytes memory."sv) + + // Get output tensor from ID. + uint32_t OutCnt = TfLiteInterpreterGetOutputTensorCount(Cxt->Interp); + if (unlikely(TensorId >= OutCnt)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Invalid tensor ID."sv); + return static_cast(ErrNo::InvalidArgument); + } + + // Copy tensor data to buffer. + const TfLiteTensor *Tensor = + TfLiteInterpreterGetOutputTensor(Cxt->Interp, TensorId); + size_t RealSize = TfLiteTensorByteSize(Tensor); + *WrittenBytes = 0U; + if (unlikely(RealSize != BufLen)) { + spdlog::error( + "[WasmEdge-Tensorflow-Lite] Unexpected buffer length: {}, output tensor size: {}."sv, + BufLen, RealSize); + return static_cast(ErrNo::InvalidArgument); + } + if (likely(Tensor != nullptr)) { + TfLiteTensorCopyToBuffer(Tensor, BufSpan.data(), RealSize); + } + return static_cast(ErrNo::Success); +} + +Expect AppendInput::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t NamePtr, + uint32_t NameLen, uint32_t TensorBufPtr, + uint32_t TensorBufLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the input tensor buffer. + MEM_SPAN_CHECK(TensorBufSpan, MemInst, uint8_t, TensorBufPtr, TensorBufLen, + "Failed when accessing the input tensor buffer memory."sv) + + // Check the input tensor operation name buffer. + MEM_SV_CHECK(NameSV, MemInst, NamePtr, NameLen, + "Failed when accessing the input name buffer memory."sv) + + // Find the input tensor. + bool IsFound = false; + uint32_t InCnt = TfLiteInterpreterGetInputTensorCount(Cxt->Interp); + for (uint32_t I = 0; I < InCnt; ++I) { + TfLiteTensor *Tensor = TfLiteInterpreterGetInputTensor(Cxt->Interp, I); + if (NameSV == std::string(TfLiteTensorName(Tensor))) { + size_t RealSize = TfLiteTensorByteSize(Tensor); + if (unlikely(RealSize != TensorBufLen)) { + spdlog::error( + "[WasmEdge-Tensorflow-Lite] Unexpected buffer length: {}, " + "input tensor size: {}."sv, + TensorBufLen, RealSize); + return static_cast(ErrNo::InvalidArgument); + } + TfLiteStatus Stat = TfLiteTensorCopyFromBuffer( + Tensor, TensorBufSpan.data(), TensorBufLen); + if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { + spdlog::error( + "[WasmEdge-Tensorflow-Lite] Copy data from tensor {} failed."sv, + NameSV); + return static_cast(ErrNo::Busy); + } + IsFound = true; + break; + } + } + if (unlikely(!IsFound)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Input tensor {} not found."sv, + NameSV); + return static_cast(ErrNo::InvalidArgument); + } + + return static_cast(ErrNo::Success); +} + +} // namespace WasmEdgeTensorflowLite +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_func.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.h new file mode 100644 index 00000000..90e29f0b --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.h @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "tensorflowlite_base.h" + +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflowLite { + +class CreateSession : public Func { +public: + CreateSession(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ModBufPtr, + uint32_t ModBufLen, uint32_t SessionIdPtr); +}; + +class DeleteSession : public Func { +public: + DeleteSession(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +class RunSession : public Func { +public: + RunSession(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +class GetOutputTensor : public Func { +public: + GetOutputTensor(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t NamePtr, uint32_t NameLen, + uint32_t TensorIdPtr); +}; + +class GetTensorLen : public Func { +public: + GetTensorLen(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t TensorId, uint32_t LenPtr); +}; + +class GetTensorData : public Func { +public: + GetTensorData(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t TensorId, uint32_t BufPtr, uint32_t BufLen, + uint32_t WrittenBytesPtr); +}; + +class AppendInput : public Func { +public: + AppendInput(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t NamePtr, uint32_t NameLen, + uint32_t TensorBufPtr, uint32_t TensorBufLen); +}; + +} // namespace WasmEdgeTensorflowLite +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp new file mode 100644 index 00000000..0681849a --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "tensorflowlite_module.h" +#include "tensorflowlite_func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasmEdgeTensorflowLiteModule::WasmEdgeTensorflowLiteModule() + : Runtime::Instance::ModuleInstance("wasmedge_tensorflowlite") { + addHostFunc("create_session", + std::make_unique(Env)); + addHostFunc("delete_session", + std::make_unique(Env)); + addHostFunc("run_session", + std::make_unique(Env)); + addHostFunc("get_output_tensor", + std::make_unique(Env)); + addHostFunc("get_tensor_len", + std::make_unique(Env)); + addHostFunc("get_tensor_data", + std::make_unique(Env)); + addHostFunc("append_input", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_module.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.h new file mode 100644 index 00000000..1f5161b7 --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "tensorflowlite_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeTensorflowLiteModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeTensorflowLiteModule(); + ~WasmEdgeTensorflowLiteModule() = default; + + WasmEdgeTensorflowLite::TFLiteEnv &getEnv() { return Env; } + +private: + WasmEdgeTensorflowLite::TFLiteEnv Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/CMakeLists.txt b/plugins/wasmedge_zlib/CMakeLists.txt new file mode 100644 index 00000000..56745021 --- /dev/null +++ b/plugins/wasmedge_zlib/CMakeLists.txt @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +find_package(ZLIB REQUIRED) + +set(ZLIB_COMPAT ON) + +wasmedge_add_library(wasmedgePluginWasmEdgeZlib + SHARED + zlibenv.cpp + zlibfunc.cpp + zlibmodule.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeZlib + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeZlib + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeZlib + PRIVATE + wasmedgeCAPI + z + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeZlib + PRIVATE + wasmedge_shared + z + ) +endif() + +install( + TARGETS wasmedgePluginWasmEdgeZlib + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_zlib/zlibbase.h b/plugins/wasmedge_zlib/zlibbase.h new file mode 100644 index 00000000..63b9a16e --- /dev/null +++ b/plugins/wasmedge_zlib/zlibbase.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "zlibenv.h" + +#include "common/errcode.h" +#include "runtime/callingframe.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template class WasmEdgeZlib : public Runtime::HostFunction { +public: + WasmEdgeZlib(WasmEdgeZlibEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasmEdgeZlibEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibenv.cpp b/plugins/wasmedge_zlib/zlibenv.cpp new file mode 100644 index 00000000..f3e8eaa4 --- /dev/null +++ b/plugins/wasmedge_zlib/zlibenv.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "zlibenv.h" +#include "zlibmodule.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeZlibModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_zlib", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 1, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_zlib", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibenv.h b/plugins/wasmedge_zlib/zlibenv.h new file mode 100644 index 00000000..407772ff --- /dev/null +++ b/plugins/wasmedge_zlib/zlibenv.h @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include +#include +#include + +#include + +/** + * @brief A struct that maps exactly to a 32-bit Wasm z_stream object + * + */ +struct WasmZStream { + /* [Wasm Offset] next input byte */ + uint32_t NextIn; + /* number of bytes available at next_in */ + uint32_t AvailIn; + /* total number of input bytes read so far */ + uint32_t TotalIn; + + /* [Wasm Offset] next output byte will go here */ + uint32_t NextOut; + /* remaining free space at next_out */ + uint32_t AvailOut; + /* total number of bytes output so far */ + uint32_t TotalOut; + + /* [Wasm Offset] last error message, NULL if no error */ + uint32_t Msg; + /* [Wasm Offset] not visible by applications */ + uint32_t State; + + /* used to allocate the internal state */ + uint32_t Zalloc; + /* used to free the internal state */ + uint32_t Zfree; + /* [Wasm Offset] private data object passed to zalloc and zfree */ + uint32_t Opaque; + + /* best guess about the data type: binary or text for deflate, or the decoding + state for inflate */ + int32_t DataType; + + /* Adler-32 or CRC-32 value of the uncompressed data */ + uint32_t Adler; + /* reserved for future use */ + uint32_t Reserved; +}; +static_assert(sizeof(WasmZStream) == 56, "WasmZStream should be 56 bytes"); + +/* + gzip header information passed to and from zlib routines. See RFC 1952 for + more details on the meanings of these fields. +*/ +struct WasmGZHeader { + int32_t Text; /* true if compressed data believed to be text */ + uint32_t Time; /* modification time */ + int32_t XFlags; /* extra flags (not used when writing a gzip file) */ + int32_t OS; /* operating system */ + uint32_t Extra; /* pointer to extra field or Z_NULL if none */ + uint32_t ExtraLen; /* extra field length (valid if extra != Z_NULL) */ + uint32_t ExtraMax; /* space at extra (only when reading header) */ + uint32_t Name; /* pointer to zero-terminated file name or Z_NULL */ + uint32_t NameMax; /* space at name (only when reading header) */ + uint32_t Comment; /* pointer to zero-terminated comment or Z_NULL */ + uint32_t CommMax; /* space at comment (only when reading header) */ + int32_t HCRC; /* true if there was or will be a header crc */ + int32_t Done; /* true when done reading gzip header (not used + when writing a gzip file) */ +}; +static_assert(sizeof(WasmGZHeader) == 52, "WasmGZHeader should be 52 bytes"); + +namespace WasmEdge { +namespace Host { + +class WasmEdgeZlibEnvironment { +public: + using GZFile = std::remove_pointer_t; + + struct GZStore { + uint32_t WasmGZHeaderOffset; + std::unique_ptr HostGZHeader; + }; + + std::unordered_map> ZStreamMap; + std::map, std::greater> GZFileMap; + std::unordered_map GZHeaderMap; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibfunc.cpp b/plugins/wasmedge_zlib/zlibfunc.cpp new file mode 100644 index 00000000..b283f69a --- /dev/null +++ b/plugins/wasmedge_zlib/zlibfunc.cpp @@ -0,0 +1,1364 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "zlibfunc.h" + +#include + +namespace WasmEdge { +namespace Host { + +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Zlib] Memory instance not found."sv); \ + return Unexpect(ErrCode::Value::HostFuncError); \ + } + +constexpr bool CheckSize(int32_t StreamSize) { + + return (StreamSize == static_cast(sizeof(WasmZStream))); +} + +static constexpr uint32_t WasmGZFileStart = sizeof(gzFile); + +template +auto SyncRun(const std::string_view &Msg, WasmEdgeZlibEnvironment &Env, + uint32_t ZStreamPtr, const Runtime::CallingFrame &Frame, + T Callback) -> Expect { + + MEMINST_CHECK(MemInst, Frame, 0) + WasmZStream *ModuleZStream = MemInst->getPointer(ZStreamPtr); + + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [{}-SyncRun] "sv + "Invalid ZStreamPtr received."sv, + Msg); + return Unexpect(ErrCode::Value::HostFuncError); + } + auto HostZStream = HostZStreamIt->second.get(); + const auto GZHeaderStoreIt = Env.GZHeaderMap.find(ZStreamPtr); + + HostZStream->next_in = + MemInst->getPointer(ModuleZStream->NextIn); + HostZStream->avail_in = ModuleZStream->AvailIn; + HostZStream->total_in = ModuleZStream->TotalIn; + + HostZStream->next_out = + MemInst->getPointer(ModuleZStream->NextOut); + HostZStream->avail_out = ModuleZStream->AvailOut; + HostZStream->total_out = ModuleZStream->TotalOut; + + // TODO: ignore msg for now + // ignore state + // ignore zalloc, zfree, opaque + + HostZStream->data_type = ModuleZStream->DataType; + HostZStream->adler = ModuleZStream->Adler; + HostZStream->reserved = ModuleZStream->Reserved; + + const auto PreComputeNextIn = HostZStream->next_in; + const auto PreComputeNextOut = HostZStream->next_out; + + unsigned char *PreComputeExtra{}; + unsigned char *PreComputeName{}; + unsigned char *PreComputeComment{}; + + if (GZHeaderStoreIt != Env.GZHeaderMap.end()) { + // Sync GZ Header + + auto *ModuleGZHeader = MemInst->getPointer( + GZHeaderStoreIt->second.WasmGZHeaderOffset); + auto *HostGZHeader = GZHeaderStoreIt->second.HostGZHeader.get(); + + HostGZHeader->text = ModuleGZHeader->Text; + HostGZHeader->time = ModuleGZHeader->Time; + HostGZHeader->xflags = ModuleGZHeader->XFlags; + HostGZHeader->os = ModuleGZHeader->OS; + + HostGZHeader->extra = + MemInst->getPointer(ModuleGZHeader->Extra); + HostGZHeader->extra_len = ModuleGZHeader->ExtraLen; + HostGZHeader->extra_max = ModuleGZHeader->ExtraMax; + + HostGZHeader->name = + MemInst->getPointer(ModuleGZHeader->Name); + HostGZHeader->name_max = ModuleGZHeader->NameMax; + + HostGZHeader->comment = + MemInst->getPointer(ModuleGZHeader->Comment); + HostGZHeader->comm_max = ModuleGZHeader->CommMax; + + HostGZHeader->hcrc = ModuleGZHeader->HCRC; + HostGZHeader->done = ModuleGZHeader->Done; + + PreComputeExtra = HostGZHeader->extra; + PreComputeName = HostGZHeader->name; + PreComputeComment = HostGZHeader->comment; + } + + const auto ZRes = Callback(HostZStream); + + ModuleZStream->NextIn += HostZStream->next_in - PreComputeNextIn; + ModuleZStream->AvailIn = HostZStream->avail_in; + ModuleZStream->TotalIn = HostZStream->total_in; + + ModuleZStream->NextOut += HostZStream->next_out - PreComputeNextOut; + ModuleZStream->AvailOut = HostZStream->avail_out; + ModuleZStream->TotalOut = HostZStream->total_out; + + // TODO: ignore msg for now + // ignore state + // ignore zalloc, zfree, opaque + + ModuleZStream->DataType = HostZStream->data_type; + ModuleZStream->Adler = HostZStream->adler; + ModuleZStream->Reserved = HostZStream->reserved; + + if (GZHeaderStoreIt != Env.GZHeaderMap.end()) { + // Sync GZ Header + + auto *ModuleGZHeader = MemInst->getPointer( + GZHeaderStoreIt->second.WasmGZHeaderOffset); + auto *HostGZHeader = GZHeaderStoreIt->second.HostGZHeader.get(); + + ModuleGZHeader->Text = HostGZHeader->text; + ModuleGZHeader->Time = HostGZHeader->time; + ModuleGZHeader->XFlags = HostGZHeader->xflags; + ModuleGZHeader->OS = HostGZHeader->os; + + ModuleGZHeader->Extra += HostGZHeader->extra - PreComputeExtra; + ModuleGZHeader->ExtraLen = HostGZHeader->extra_len; + ModuleGZHeader->ExtraMax = HostGZHeader->extra_max; + + ModuleGZHeader->Name += HostGZHeader->name - PreComputeName; + ModuleGZHeader->NameMax = HostGZHeader->name_max; + + ModuleGZHeader->Comment += HostGZHeader->comment - PreComputeComment; + ModuleGZHeader->CommMax = HostGZHeader->comm_max; + + ModuleGZHeader->HCRC = HostGZHeader->hcrc; + ModuleGZHeader->Done = HostGZHeader->done; + } + + return ZRes; +} + +Expect +WasmEdgeZlibDeflateInit::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Level) { + + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = SyncRun( + "WasmEdgeZlibDeflateInit", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return deflateInit(HostZStream, Level); }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibDeflate::WasmEdgeZlibDeflate::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, int32_t Flush) { + + return SyncRun( + "WasmEdgeZlibDeflate", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return deflate(HostZStream, Flush); }); +} + +Expect WasmEdgeZlibDeflateEnd::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + const auto ZRes = + SyncRun("WasmEdgeZlibDeflateEnd", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return deflateEnd(HostZStream); }); + + if (ZRes == Z_OK) + Env.ZStreamMap.erase(ZStreamPtr); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateInit::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateInit", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateInit(HostZStream); }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibInflate::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Flush) { + + return SyncRun( + "WasmEdgeZlibInflate", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflate(HostZStream, Flush); }); +} + +Expect WasmEdgeZlibInflateEnd::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateEnd", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateEnd(HostZStream); }); + + Env.ZStreamMap.erase(ZStreamPtr); + + return ZRes; +} + +Expect WasmEdgeZlibDeflateInit2::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, int32_t Level, + int32_t Method, int32_t WindowBits, int32_t MemLevel, int32_t Strategy) { + + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibDeflateInit2", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateInit2(HostZStream, Level, Method, WindowBits, + MemLevel, Strategy); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibDeflateSetDictionary::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLength) { + + MEMINST_CHECK(MemInst, Frame, 0) + + const auto *Dictionary = MemInst->getPointer(DictionaryPtr); + + return SyncRun("WasmEdgeZlibDeflateSetDictionary", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateSetDictionary(HostZStream, Dictionary, + DictLength); + }); +} + +Expect WasmEdgeZlibDeflateGetDictionary::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLengthPtr) { + + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Dictionary = MemInst->getPointer(DictionaryPtr); + auto *DictLength = MemInst->getPointer(DictLengthPtr); + + return SyncRun("WasmEdgeZlibDeflateGetDictionary", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateGetDictionary(HostZStream, Dictionary, + DictLength); + }); +} + +/* +"The deflateCopy() function shall copy the compression state information in +source to the uninitialized z_stream structure referenced by dest." + +https://refspecs.linuxbase.org/LSB_3.0.0/LSB-Core-generic/LSB-Core-generic/zlib-deflatecopy-1.html +*/ +Expect +WasmEdgeZlibDeflateCopy::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, uint32_t SourcePtr) { + const auto SourceZStreamIt = Env.ZStreamMap.find(SourcePtr); + if (SourceZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateCopy] "sv + "Invalid SourcePtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto DestZStream = std::make_unique(); + + auto It = + Env.ZStreamMap.emplace(std::make_pair(DestPtr, std::move(DestZStream))) + .second; + + const auto Res = SyncRun("WasmEdgeZlibDeflateCopy", Env, DestPtr, Frame, + [&](z_stream *) { return 0; }); + if (!Res.has_value()) + return Res; + + const auto ZRes = + SyncRun("WasmEdgeZlibDeflateCopy", Env, DestPtr, Frame, + [&](z_stream *DestZStream) { + return deflateCopy(DestZStream, SourceZStreamIt->second.get()); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibDeflateReset::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + return SyncRun( + "WasmEdgeZlibDeflateReset", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return deflateReset(HostZStream); }); +} + +Expect +WasmEdgeZlibDeflateParams::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Level, + int32_t Strategy) { + + return SyncRun("WasmEdgeZlibDeflateParams", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateParams(HostZStream, Level, Strategy); + }); +} + +Expect WasmEdgeZlibDeflateTune::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, int32_t GoodLength, + int32_t MaxLazy, int32_t NiceLength, int32_t MaxChain) { + + return SyncRun("WasmEdgeZlibDeflateTune", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateTune(HostZStream, GoodLength, MaxLazy, + NiceLength, MaxChain); + }); +} + +Expect +WasmEdgeZlibDeflateBound::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, uint32_t SourceLen) { + + return SyncRun("WasmEdgeZlibDeflateBound", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateBound(HostZStream, SourceLen); + }); +} + +Expect +WasmEdgeZlibDeflatePending::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, uint32_t PendingPtr, + uint32_t BitsPtr) { + + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Pending = MemInst->getPointer(PendingPtr); + auto *Bits = MemInst->getPointer(BitsPtr); + + return SyncRun("WasmEdgeZlibDeflatePending", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflatePending(HostZStream, Pending, Bits); + }); +} + +Expect +WasmEdgeZlibDeflatePrime::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Bits, + int32_t Value) { + + return SyncRun("WasmEdgeZlibDeflatePrime", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflatePrime(HostZStream, Bits, Value); + }); +} + +Expect +WasmEdgeZlibDeflateSetHeader::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, uint32_t HeadPtr) { + + auto HostGZHeader = std::make_unique(); + auto HostGZHeaderPtr = HostGZHeader.get(); + + auto It = Env.GZHeaderMap + .emplace(std::pair{ + ZStreamPtr, + WasmEdgeZlibEnvironment::GZStore{ + .WasmGZHeaderOffset = HeadPtr, + .HostGZHeader = std::move(HostGZHeader)}}) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibDeflateSetHeader", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateSetHeader(HostZStream, HostGZHeaderPtr); + }); + + if (ZRes != Z_OK) + Env.GZHeaderMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateInit2::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t WindowBits) { + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = SyncRun("WasmEdgeZlibInflateInit2", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateInit2(HostZStream, WindowBits); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibInflateSetDictionary::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLength) { + + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Dictionary = MemInst->getPointer(DictionaryPtr); + + return SyncRun("WasmEdgeZlibInflateSetDictionary", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateSetDictionary(HostZStream, Dictionary, + DictLength); + }); +} + +Expect WasmEdgeZlibInflateGetDictionary::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLengthPtr) { + + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Dictionary = MemInst->getPointer(DictionaryPtr); + auto *DictLength = MemInst->getPointer(DictLengthPtr); + + return SyncRun("WasmEdgeZlibInflateGetDictionary", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateGetDictionary(HostZStream, Dictionary, + DictLength); + }); +} + +Expect +WasmEdgeZlibInflateSync::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + return SyncRun( + "WasmEdgeZlibInflateSync", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateSync(HostZStream); }); +} + +Expect +WasmEdgeZlibInflateCopy::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, uint32_t SourcePtr) { + const auto SourceZStreamIt = Env.ZStreamMap.find(SourcePtr); + if (SourceZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateCopy] "sv + "Invalid SourcePtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto DestZStream = std::make_unique(); + + auto It = + Env.ZStreamMap.emplace(std::make_pair(DestPtr, std::move(DestZStream))) + .second; + + const auto Res = SyncRun("WasmEdgeZlibInflateCopy", Env, DestPtr, Frame, + [&](z_stream *) { return 0; }); + if (!Res.has_value()) + return Res; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateCopy", Env, DestPtr, Frame, + [&](z_stream *DestZStream) { + return inflateCopy(DestZStream, SourceZStreamIt->second.get()); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateReset::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + return SyncRun( + "WasmEdgeZlibInflateReset", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateReset(HostZStream); }); +} + +Expect +WasmEdgeZlibInflateReset2::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t WindowBits) { + + return SyncRun("WasmEdgeZlibInflateReset2", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateReset2(HostZStream, WindowBits); + }); +} + +Expect +WasmEdgeZlibInflatePrime::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Bits, + int32_t Value) { + + return SyncRun("WasmEdgeZlibInflatePrime", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflatePrime(HostZStream, Bits, Value); + }); +} + +Expect +WasmEdgeZlibInflateMark::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + return SyncRun( + "WasmEdgeZlibInflateMark", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateMark(HostZStream); }); +} + +Expect +WasmEdgeZlibInflateGetHeader::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, uint32_t HeadPtr) { + + auto HostGZHeader = std::make_unique(); + auto HostGZHeaderPtr = HostGZHeader.get(); + + auto It = Env.GZHeaderMap + .emplace(std::pair{ + ZStreamPtr, + WasmEdgeZlibEnvironment::GZStore{ + .WasmGZHeaderOffset = HeadPtr, + .HostGZHeader = std::move(HostGZHeader)}}) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateGetHeader", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateGetHeader(HostZStream, HostGZHeaderPtr); + }); + + if (ZRes != Z_OK) + Env.GZHeaderMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateBackInit::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t WindowBits, + uint32_t WindowPtr) { + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Window = MemInst->getPointer(WindowPtr); + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateBackInit", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateBackInit(HostZStream, WindowBits, Window); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateBackEnd::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + const auto ZRes = SyncRun( + "WasmEdgeZlibInflateBackEnd", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateBackEnd(HostZStream); }); + + Env.ZStreamMap.erase(ZStreamPtr); + + return ZRes; +} + +Expect +WasmEdgeZlibZlibCompilerFlags::body(const Runtime::CallingFrame &) { + return zlibCompileFlags(); +} + +Expect WasmEdgeZlibCompress::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, + uint32_t DestLenPtr, + uint32_t SourcePtr, + uint32_t SourceLen) { + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Dest = MemInst->getPointer(DestPtr); + auto *DestLen = MemInst->getPointer(DestLenPtr); + auto *Source = MemInst->getPointer(SourcePtr); + + unsigned long HostDestLen; + HostDestLen = *DestLen; + const auto ZRes = compress(Dest, &HostDestLen, Source, SourceLen); + *DestLen = HostDestLen; + + return ZRes; +} + +Expect WasmEdgeZlibCompress2::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, + uint32_t DestLenPtr, + uint32_t SourcePtr, + uint32_t SourceLen, int32_t Level) { + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Dest = MemInst->getPointer(DestPtr); + auto *DestLen = MemInst->getPointer(DestLenPtr); + auto *Source = MemInst->getPointer(SourcePtr); + + unsigned long HostDestLen; + HostDestLen = *DestLen; + const auto ZRes = compress2(Dest, &HostDestLen, Source, SourceLen, Level); + *DestLen = HostDestLen; + + return ZRes; +} + +Expect WasmEdgeZlibCompressBound::body(const Runtime::CallingFrame &, + uint32_t SourceLen) { + return compressBound(SourceLen); +} + +Expect WasmEdgeZlibUncompress::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, + uint32_t DestLenPtr, + uint32_t SourcePtr, + uint32_t SourceLen) { + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Dest = MemInst->getPointer(DestPtr); + auto *DestLen = MemInst->getPointer(DestLenPtr); + auto *Source = MemInst->getPointer(SourcePtr); + + unsigned long HostDestLen; + HostDestLen = *DestLen; + const auto ZRes = uncompress(Dest, &HostDestLen, Source, SourceLen); + *DestLen = HostDestLen; + + return ZRes; +} + +Expect +WasmEdgeZlibUncompress2::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, uint32_t DestLenPtr, + uint32_t SourcePtr, uint32_t SourceLenPtr) { + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Dest = MemInst->getPointer(DestPtr); + auto *DestLen = MemInst->getPointer(DestLenPtr); + auto *Source = MemInst->getPointer(SourcePtr); + auto *SourceLen = MemInst->getPointer(SourceLenPtr); + + unsigned long HostDestLen, HostSourceLen; + HostDestLen = *DestLen; + HostSourceLen = *SourceLen; + const auto ZRes = uncompress2(Dest, &HostDestLen, Source, &HostSourceLen); + *DestLen = HostDestLen; + *SourceLen = HostSourceLen; + + return ZRes; +} + +Expect WasmEdgeZlibGZOpen::body(const Runtime::CallingFrame &Frame, + uint32_t PathPtr, uint32_t ModePtr) { + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Path = MemInst->getPointer(PathPtr); + auto *Mode = MemInst->getPointer(ModePtr); + + auto ZRes = gzopen(Path, Mode); + + const auto NewWasmGZFile = WasmGZFileStart + Env.GZFileMap.size(); + auto El = + std::pair>( + NewWasmGZFile, ZRes); + + Env.GZFileMap.emplace(std::move(El)); + + return NewWasmGZFile; +} + +Expect WasmEdgeZlibGZDOpen::body(const Runtime::CallingFrame &Frame, + int32_t FD, uint32_t ModePtr) { + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Mode = MemInst->getPointer(ModePtr); + + auto ZRes = gzdopen(FD, Mode); + + const auto NewWasmGZFile = WasmGZFileStart + Env.GZFileMap.size(); + auto El = + std::pair>( + NewWasmGZFile, ZRes); + + Env.GZFileMap.emplace(std::move(El)); + + return NewWasmGZFile; +} + +Expect WasmEdgeZlibGZBuffer::body(const Runtime::CallingFrame &, + uint32_t GZFile, uint32_t Size) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZBuffer] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzbuffer(GZFileIt->second.get(), Size); +} + +Expect WasmEdgeZlibGZSetParams::body(const Runtime::CallingFrame &, + uint32_t GZFile, int32_t Level, + int32_t Strategy) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZSetParams] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzsetparams(GZFileIt->second.get(), Level, Strategy); +} + +Expect WasmEdgeZlibGZRead::body(const Runtime::CallingFrame &Frame, + uint32_t GZFile, uint32_t BufPtr, + uint32_t Len) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZRead] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Buf = MemInst->getPointer(BufPtr); + + return gzread(GZFileIt->second.get(), Buf, Len); +} + +Expect WasmEdgeZlibGZFread::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr, uint32_t Size, + uint32_t NItems, uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZFread] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Buf = MemInst->getPointer(BufPtr); + + return gzfread(Buf, Size, NItems, GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZWrite::body(const Runtime::CallingFrame &Frame, + uint32_t GZFile, uint32_t BufPtr, + uint32_t Len) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZWrite] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Buf = MemInst->getPointer(BufPtr); + + return gzwrite(GZFileIt->second.get(), Buf, Len); +} + +Expect WasmEdgeZlibGZFwrite::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr, uint32_t Size, + uint32_t NItems, uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZFwrite] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Buf = MemInst->getPointer(BufPtr); + + return gzfwrite(Buf, Size, NItems, GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZPuts::body(const Runtime::CallingFrame &Frame, + uint32_t GZFile, uint32_t StringPtr) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZPuts] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + MEMINST_CHECK(MemInst, Frame, 0) + + auto *String = MemInst->getPointer(StringPtr); + + return gzputs(GZFileIt->second.get(), String); +} + +Expect WasmEdgeZlibGZPutc::body(const Runtime::CallingFrame &, + uint32_t GZFile, int32_t C) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZPutc] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzputc(GZFileIt->second.get(), C); +} + +Expect WasmEdgeZlibGZGetc::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZGetc] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzgetc(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZUngetc::body(const Runtime::CallingFrame &, + int32_t C, uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZUngetc] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzungetc(C, GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZFlush::body(const Runtime::CallingFrame &, + uint32_t GZFile, int32_t Flush) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZFlush] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzflush(GZFileIt->second.get(), Flush); +} + +Expect WasmEdgeZlibGZSeek::body(const Runtime::CallingFrame &, + uint32_t GZFile, int32_t Offset, + int32_t Whence) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZSeek] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzseek(GZFileIt->second.get(), Offset, Whence); +} + +Expect WasmEdgeZlibGZRewind::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZRewind] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzrewind(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZTell::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZTell] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gztell(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZOffset::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZOffset] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzoffset(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZEof::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZEof] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzeof(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZDirect::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZDirect] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzdirect(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZClose::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZClose] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto ZRes = gzclose(GZFileIt->second.get()); + + Env.GZFileMap.erase(GZFileIt); + + return ZRes; +} + +Expect WasmEdgeZlibGZClose_r::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZClose_r] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto ZRes = gzclose_r(GZFileIt->second.get()); + + Env.GZFileMap.erase(GZFileIt); + + return ZRes; +} + +Expect WasmEdgeZlibGZClose_w::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZClose_w] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto ZRes = gzclose_w(GZFileIt->second.get()); + + Env.GZFileMap.erase(GZFileIt); + + return ZRes; +} + +Expect WasmEdgeZlibGZClearerr::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZClearerr] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + gzclearerr(GZFileIt->second.get()); + + return Expect{}; +} + +Expect WasmEdgeZlibAdler32::body(const Runtime::CallingFrame &Frame, + uint32_t Adler, uint32_t BufPtr, + uint32_t Len) { + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Buf = MemInst->getPointer(BufPtr); + + return adler32(Adler, Buf, Len); +} + +Expect WasmEdgeZlibAdler32_z::body(const Runtime::CallingFrame &Frame, + uint32_t Adler, uint32_t BufPtr, + uint32_t Len) { + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Buf = MemInst->getPointer(BufPtr); + + return adler32_z(Adler, Buf, Len); +} + +Expect WasmEdgeZlibAdler32Combine::body(const Runtime::CallingFrame &, + uint32_t Adler1, + uint32_t Adler2, + int32_t Len2) { + return adler32_combine(Adler1, Adler2, Len2); +} + +Expect WasmEdgeZlibCRC32::body(const Runtime::CallingFrame &Frame, + uint32_t CRC, uint32_t BufPtr, + uint32_t Len) { + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Buf = MemInst->getPointer(BufPtr); + + return crc32(CRC, Buf, Len); +} + +Expect WasmEdgeZlibCRC32_z::body(const Runtime::CallingFrame &Frame, + uint32_t CRC, uint32_t BufPtr, + uint32_t Len) { + MEMINST_CHECK(MemInst, Frame, 0) + + auto *Buf = MemInst->getPointer(BufPtr); + + return crc32_z(CRC, Buf, Len); +} + +Expect WasmEdgeZlibCRC32Combine::body(const Runtime::CallingFrame &, + uint32_t CRC1, uint32_t CRC2, + int32_t Len2) { + return crc32_combine(CRC1, CRC2, Len2); +} + +Expect +WasmEdgeZlibDeflateInit_::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Level, + uint32_t VersionPtr, int32_t StreamSize) { + if (!CheckSize(StreamSize)) + return static_cast(Z_VERSION_ERROR); + + MEMINST_CHECK(MemInst, Frame, 0) + + const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); + auto HostZStream = std::make_unique(); + + // ignore wasm custom allocators + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + // Ignore opaque because zmalloc and zfree are ignored. + HostZStream->opaque = Z_NULL; + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibDeflateInit_", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateInit_(HostZStream, Level, WasmZlibVersion, + sizeof(z_stream)); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateInit_::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, uint32_t VersionPtr, + int32_t StreamSize) { + if (!CheckSize(StreamSize)) + return static_cast(Z_VERSION_ERROR); + + MEMINST_CHECK(MemInst, Frame, 0) + + const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); + auto HostZStream = std::make_unique(); + + // ignore wasm custom allocators + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + // Ignore opaque because zmalloc and zfree are ignored. + HostZStream->opaque = Z_NULL; + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = SyncRun("WasmEdgeZlibInflateInit_", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateInit_(HostZStream, WasmZlibVersion, + sizeof(z_stream)); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibDeflateInit2_::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, int32_t Level, + int32_t Method, int32_t WindowBits, int32_t MemLevel, int32_t Strategy, + uint32_t VersionPtr, int32_t StreamSize) { + if (!CheckSize(StreamSize)) + return static_cast(Z_VERSION_ERROR); + + MEMINST_CHECK(MemInst, Frame, 0) + + const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = SyncRun( + "WasmEdgeZlibDeflateInit2_", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateInit2_(HostZStream, Level, Method, WindowBits, MemLevel, + Strategy, WasmZlibVersion, sizeof(z_stream)); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateInit2_::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t WindowBits, + uint32_t VersionPtr, int32_t StreamSize) { + if (!CheckSize(StreamSize)) + return static_cast(Z_VERSION_ERROR); + + MEMINST_CHECK(MemInst, Frame, 0) + + const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateInit2_", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateInit2_(HostZStream, WindowBits, WasmZlibVersion, + sizeof(z_stream)); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibInflateBackInit_::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, int32_t WindowBits, + uint32_t WindowPtr, uint32_t VersionPtr, int32_t StreamSize) { + if (!CheckSize(StreamSize)) + return static_cast(Z_VERSION_ERROR); + + MEMINST_CHECK(MemInst, Frame, 0) + + const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); + auto *Window = MemInst->getPointer(WindowPtr); + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateBackInit_", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateBackInit_(HostZStream, WindowBits, Window, + WasmZlibVersion, sizeof(z_stream)); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibGZGetc_::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZGetc_] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzgetc_(GZFileIt->second.get()); +} + +Expect +WasmEdgeZlibInflateSyncPoint::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateSyncPoint] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return inflateSyncPoint(HostZStreamIt->second.get()); +} + +Expect +WasmEdgeZlibInflateUndermine::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr, int32_t Subvert) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateUndermine] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return inflateUndermine(HostZStreamIt->second.get(), Subvert); +} + +Expect WasmEdgeZlibInflateValidate::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr, + int32_t Check) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateValidate] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return inflateValidate(HostZStreamIt->second.get(), Check); +} + +Expect +WasmEdgeZlibInflateCodesUsed::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateCodesUsed] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return inflateCodesUsed(HostZStreamIt->second.get()); +} + +Expect +WasmEdgeZlibInflateResetKeep::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateResetKeep] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return inflateResetKeep(HostZStreamIt->second.get()); +} + +Expect +WasmEdgeZlibDeflateResetKeep::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateResetKeep] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return deflateResetKeep(HostZStreamIt->second.get()); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibfunc.h b/plugins/wasmedge_zlib/zlibfunc.h new file mode 100644 index 00000000..21053cbb --- /dev/null +++ b/plugins/wasmedge_zlib/zlibfunc.h @@ -0,0 +1,631 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "zlibbase.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeZlibDeflateInit : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateInit(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Level); +}; + +class WasmEdgeZlibDeflate : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflate(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Flush); +}; + +class WasmEdgeZlibDeflateEnd : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateEnd(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateInit : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateInit(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflate : public WasmEdgeZlib { +public: + WasmEdgeZlibInflate(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Flush); +}; + +class WasmEdgeZlibInflateEnd : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateEnd(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibDeflateInit2 : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateInit2(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Level, int32_t Method, int32_t WindowBits, + int32_t MemLevel, int32_t Strategy); +}; + +class WasmEdgeZlibDeflateSetDictionary + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateSetDictionary(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLength); +}; + +class WasmEdgeZlibDeflateGetDictionary + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateGetDictionary(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLengthPtr); +}; + +class WasmEdgeZlibDeflateCopy : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateCopy(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t SourcePtr); +}; + +class WasmEdgeZlibDeflateReset : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateReset(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibDeflateParams + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateParams(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Level, int32_t Strategy); +}; + +class WasmEdgeZlibDeflateTune : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateTune(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t GoodLength, int32_t MaxLazy, int32_t NiceLength, + int32_t MaxChain); +}; + +// https://github.com/emscripten-core/emscripten/issues/17009 +// Use 32-bit because long is 32-bit wide on the Wasm side. +class WasmEdgeZlibDeflateBound : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateBound(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t SourceLen); +}; + +class WasmEdgeZlibDeflatePending + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflatePending(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t PendingPtr, uint32_t BitsPtr); +}; + +class WasmEdgeZlibDeflatePrime : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflatePrime(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Bits, int32_t Value); +}; + +class WasmEdgeZlibDeflateSetHeader + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateSetHeader(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t HeadPtr); +}; + +class WasmEdgeZlibInflateInit2 : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateInit2(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t WindowBits); +}; + +class WasmEdgeZlibInflateSetDictionary + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateSetDictionary(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLength); +}; + +class WasmEdgeZlibInflateGetDictionary + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateGetDictionary(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLengthPtr); +}; + +class WasmEdgeZlibInflateSync : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateSync(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateCopy : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateCopy(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t SourcePtr); +}; + +class WasmEdgeZlibInflateReset : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateReset(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateReset2 + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateReset2(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t WindowBits); +}; + +class WasmEdgeZlibInflatePrime : public WasmEdgeZlib { +public: + WasmEdgeZlibInflatePrime(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Bits, int32_t Value); +}; + +class WasmEdgeZlibInflateMark : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateMark(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateGetHeader + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateGetHeader(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t HeadPtr); +}; + +class WasmEdgeZlibInflateBackInit + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateBackInit(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t WindowBits, uint32_t WindowPtr); +}; + +class WasmEdgeZlibInflateBackEnd + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateBackEnd(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibZlibCompilerFlags + : public WasmEdgeZlib { +public: + WasmEdgeZlibZlibCompilerFlags(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class WasmEdgeZlibCompress : public WasmEdgeZlib { +public: + WasmEdgeZlibCompress(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t DestLenPtr, uint32_t SourcePtr, + uint32_t SourceLen); +}; + +class WasmEdgeZlibCompress2 : public WasmEdgeZlib { +public: + WasmEdgeZlibCompress2(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t DestLenPtr, uint32_t SourcePtr, + uint32_t SourceLen, int32_t Level); +}; + +class WasmEdgeZlibCompressBound + : public WasmEdgeZlib { +public: + WasmEdgeZlibCompressBound(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SourceLen); +}; + +class WasmEdgeZlibUncompress : public WasmEdgeZlib { +public: + WasmEdgeZlibUncompress(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t DestLenPtr, uint32_t SourcePtr, + uint32_t SourceLen); +}; + +class WasmEdgeZlibUncompress2 : public WasmEdgeZlib { +public: + WasmEdgeZlibUncompress2(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t DestLenPtr, uint32_t SourcePtr, + uint32_t SourceLenPtr); +}; + +class WasmEdgeZlibGZOpen : public WasmEdgeZlib { +public: + WasmEdgeZlibGZOpen(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t PathPtr, + uint32_t ModePtr); +}; + +class WasmEdgeZlibGZDOpen : public WasmEdgeZlib { +public: + WasmEdgeZlibGZDOpen(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t FD, + uint32_t ModePtr); +}; + +class WasmEdgeZlibGZBuffer : public WasmEdgeZlib { +public: + WasmEdgeZlibGZBuffer(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + uint32_t Size); +}; + +class WasmEdgeZlibGZSetParams : public WasmEdgeZlib { +public: + WasmEdgeZlibGZSetParams(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + int32_t Level, int32_t Strategy); +}; + +class WasmEdgeZlibGZRead : public WasmEdgeZlib { +public: + WasmEdgeZlibGZRead(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + uint32_t BufPtr, uint32_t Len); +}; + +class WasmEdgeZlibGZFread : public WasmEdgeZlib { +public: + WasmEdgeZlibGZFread(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr, + uint32_t Size, uint32_t NItems, uint32_t GZFile); +}; + +class WasmEdgeZlibGZWrite : public WasmEdgeZlib { +public: + WasmEdgeZlibGZWrite(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + uint32_t BufPtr, uint32_t Len); +}; + +class WasmEdgeZlibGZFwrite : public WasmEdgeZlib { +public: + WasmEdgeZlibGZFwrite(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr, + uint32_t Size, uint32_t NItems, uint32_t GZFile); +}; + +class WasmEdgeZlibGZPuts : public WasmEdgeZlib { +public: + WasmEdgeZlibGZPuts(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + uint32_t StringPtr); +}; + +class WasmEdgeZlibGZPutc : public WasmEdgeZlib { +public: + WasmEdgeZlibGZPutc(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + int32_t C); +}; + +class WasmEdgeZlibGZGetc : public WasmEdgeZlib { +public: + WasmEdgeZlibGZGetc(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZUngetc : public WasmEdgeZlib { +public: + WasmEdgeZlibGZUngetc(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t C, + uint32_t GZFile); +}; + +class WasmEdgeZlibGZFlush : public WasmEdgeZlib { +public: + WasmEdgeZlibGZFlush(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + int32_t Flush); +}; + +// z_off_t --> long +class WasmEdgeZlibGZSeek : public WasmEdgeZlib { +public: + WasmEdgeZlibGZSeek(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + int32_t Offset, int32_t Whence); +}; + +class WasmEdgeZlibGZRewind : public WasmEdgeZlib { +public: + WasmEdgeZlibGZRewind(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZTell : public WasmEdgeZlib { +public: + WasmEdgeZlibGZTell(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZOffset : public WasmEdgeZlib { +public: + WasmEdgeZlibGZOffset(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZEof : public WasmEdgeZlib { +public: + WasmEdgeZlibGZEof(WasmEdgeZlibEnvironment &HostEnv) : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZDirect : public WasmEdgeZlib { +public: + WasmEdgeZlibGZDirect(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZClose : public WasmEdgeZlib { +public: + WasmEdgeZlibGZClose(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZClose_r : public WasmEdgeZlib { +public: + WasmEdgeZlibGZClose_r(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZClose_w : public WasmEdgeZlib { +public: + WasmEdgeZlibGZClose_w(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZClearerr : public WasmEdgeZlib { +public: + WasmEdgeZlibGZClearerr(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibAdler32 : public WasmEdgeZlib { +public: + WasmEdgeZlibAdler32(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Adler, + uint32_t BufPtr, uint32_t Len); +}; + +class WasmEdgeZlibAdler32_z : public WasmEdgeZlib { +public: + WasmEdgeZlibAdler32_z(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Adler, + uint32_t BufPtr, uint32_t Len); +}; + +// z_off_t --> long +class WasmEdgeZlibAdler32Combine + : public WasmEdgeZlib { +public: + WasmEdgeZlibAdler32Combine(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Adler1, + uint32_t Adler2, int32_t Len2); +}; + +class WasmEdgeZlibCRC32 : public WasmEdgeZlib { +public: + WasmEdgeZlibCRC32(WasmEdgeZlibEnvironment &HostEnv) : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t CRC, + uint32_t BufPtr, uint32_t Len); +}; + +class WasmEdgeZlibCRC32_z : public WasmEdgeZlib { +public: + WasmEdgeZlibCRC32_z(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t CRC, + uint32_t BufPtr, uint32_t Len); +}; + +// z_off_t --> long +class WasmEdgeZlibCRC32Combine : public WasmEdgeZlib { +public: + WasmEdgeZlibCRC32Combine(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t CRC1, + uint32_t CRC2, int32_t Len2); +}; + +class WasmEdgeZlibDeflateInit_ : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateInit_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Level, uint32_t VersionPtr, int32_t StreamSize); +}; + +class WasmEdgeZlibInflateInit_ : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateInit_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t VersionPtr, int32_t StreamSize); +}; + +class WasmEdgeZlibDeflateInit2_ + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateInit2_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Level, int32_t Method, int32_t WindowBits, + int32_t MemLevel, int32_t Strategy, uint32_t VersionPtr, + int32_t StreamSize); +}; + +class WasmEdgeZlibInflateInit2_ + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateInit2_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t WindowBits, uint32_t VersionPtr, + int32_t StreamSize); +}; + +class WasmEdgeZlibInflateBackInit_ + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateBackInit_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t WindowBits, uint32_t WindowPtr, + uint32_t VersionPtr, int32_t StreamSize); +}; + +class WasmEdgeZlibGZGetc_ : public WasmEdgeZlib { +public: + WasmEdgeZlibGZGetc_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibInflateSyncPoint + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateSyncPoint(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateUndermine + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateUndermine(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Subvert); +}; + +class WasmEdgeZlibInflateValidate + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateValidate(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Check); +}; + +class WasmEdgeZlibInflateCodesUsed + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateCodesUsed(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateResetKeep + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateResetKeep(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibDeflateResetKeep + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateResetKeep(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibmodule.cpp b/plugins/wasmedge_zlib/zlibmodule.cpp new file mode 100644 index 00000000..4e39eaa6 --- /dev/null +++ b/plugins/wasmedge_zlib/zlibmodule.cpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "zlibmodule.h" +#include "zlibfunc.h" + +namespace WasmEdge { +namespace Host { + +/// Register your functions in module. +WasmEdgeZlibModule::WasmEdgeZlibModule() : ModuleInstance("wasmedge_zlib") { + addHostFunc("deflateInit", std::make_unique(Env)); + addHostFunc("deflate", std::make_unique(Env)); + addHostFunc("deflateEnd", std::make_unique(Env)); + addHostFunc("inflateInit", std::make_unique(Env)); + addHostFunc("inflate", std::make_unique(Env)); + addHostFunc("inflateEnd", std::make_unique(Env)); + addHostFunc("deflateInit2", std::make_unique(Env)); + addHostFunc("deflateSetDictionary", + std::make_unique(Env)); + addHostFunc("deflateGetDictionary", + std::make_unique(Env)); + addHostFunc("deflateCopy", std::make_unique(Env)); + addHostFunc("deflateReset", std::make_unique(Env)); + addHostFunc("deflateParams", + std::make_unique(Env)); + addHostFunc("deflateTune", std::make_unique(Env)); + addHostFunc("deflateBound", std::make_unique(Env)); + addHostFunc("deflatePending", + std::make_unique(Env)); + addHostFunc("deflatePrime", std::make_unique(Env)); + addHostFunc("deflateSetHeader", + std::make_unique(Env)); + addHostFunc("inflateInit2", std::make_unique(Env)); + addHostFunc("inflateSetDictionary", + std::make_unique(Env)); + addHostFunc("inflateGetDictionary", + std::make_unique(Env)); + addHostFunc("inflateSync", std::make_unique(Env)); + addHostFunc("inflateCopy", std::make_unique(Env)); + addHostFunc("inflateReset", std::make_unique(Env)); + addHostFunc("inflateReset2", + std::make_unique(Env)); + addHostFunc("inflatePrime", std::make_unique(Env)); + addHostFunc("inflateMark", std::make_unique(Env)); + addHostFunc("inflateGetHeader", + std::make_unique(Env)); + addHostFunc("inflateBackInit", + std::make_unique(Env)); + addHostFunc("inflateBackEnd", + std::make_unique(Env)); + addHostFunc("zlibCompileFlags", + std::make_unique(Env)); + addHostFunc("compress", std::make_unique(Env)); + addHostFunc("compress2", std::make_unique(Env)); + addHostFunc("compressBound", + std::make_unique(Env)); + addHostFunc("uncompress", std::make_unique(Env)); + addHostFunc("uncompress2", std::make_unique(Env)); + addHostFunc("gzopen", std::make_unique(Env)); + addHostFunc("gzdopen", std::make_unique(Env)); + addHostFunc("gzbuffer", std::make_unique(Env)); + addHostFunc("gzsetparams", std::make_unique(Env)); + addHostFunc("gzread", std::make_unique(Env)); + addHostFunc("gzfread", std::make_unique(Env)); + addHostFunc("gzwrite", std::make_unique(Env)); + addHostFunc("gzfwrite", std::make_unique(Env)); + addHostFunc("gzputs", std::make_unique(Env)); + addHostFunc("gzputc", std::make_unique(Env)); + addHostFunc("gzgetc", std::make_unique(Env)); + addHostFunc("gzungetc", std::make_unique(Env)); + addHostFunc("gzflush", std::make_unique(Env)); + addHostFunc("gzseek", std::make_unique(Env)); + addHostFunc("gzrewind", std::make_unique(Env)); + addHostFunc("gztell", std::make_unique(Env)); + addHostFunc("gzoffset", std::make_unique(Env)); + addHostFunc("gzeof", std::make_unique(Env)); + addHostFunc("gzdirect", std::make_unique(Env)); + addHostFunc("gzclose", std::make_unique(Env)); + addHostFunc("gzclose_r", std::make_unique(Env)); + addHostFunc("gzclose_w", std::make_unique(Env)); + addHostFunc("gzclearerr", std::make_unique(Env)); + addHostFunc("adler32", std::make_unique(Env)); + addHostFunc("adler32_z", std::make_unique(Env)); + addHostFunc("adler32_combine", + std::make_unique(Env)); + addHostFunc("crc32", std::make_unique(Env)); + addHostFunc("crc32_z", std::make_unique(Env)); + addHostFunc("crc32_combine", std::make_unique(Env)); + addHostFunc("deflateInit_", std::make_unique(Env)); + addHostFunc("inflateInit_", std::make_unique(Env)); + addHostFunc("deflateInit2_", + std::make_unique(Env)); + addHostFunc("inflateInit2_", + std::make_unique(Env)); + addHostFunc("inflateBackInit_", + std::make_unique(Env)); + addHostFunc("gzgetc_", std::make_unique(Env)); + addHostFunc("inflateSyncPoint", + std::make_unique(Env)); + addHostFunc("inflateUndermine", + std::make_unique(Env)); + addHostFunc("inflateValidate", + std::make_unique(Env)); + addHostFunc("inflateCodesUsed", + std::make_unique(Env)); + addHostFunc("inflateResetKeep", + std::make_unique(Env)); + addHostFunc("deflateResetKeep", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibmodule.h b/plugins/wasmedge_zlib/zlibmodule.h new file mode 100644 index 00000000..dd595124 --- /dev/null +++ b/plugins/wasmedge_zlib/zlibmodule.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "runtime/instance/module.h" +#include "zlibenv.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeZlibModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeZlibModule(); + + WasmEdgeZlibEnvironment &getEnv() { return Env; } + +private: + WasmEdgeZlibEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt new file mode 100644 index 00000000..7ead52f7 --- /dev/null +++ b/test/plugins/CMakeLists.txt @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +# WASI plug-in: WASI-Crypto proposal. +if(WASMEDGE_PLUGIN_WASI_CRYPTO) + add_subdirectory(wasi_crypto) +endif() + +# WASI plug-in: WASI-Logging proposal. +add_subdirectory(wasi_logging) + +# WASI plug-in: WASI-NN proposal with backends. +if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) + add_subdirectory(wasi_nn) +endif() + +# WasmEdge plug-in: wasm-bpf. +if(WASMEDGE_PLUGIN_WASM_BPF) + # wasm_bpf is currently supported only on Linux systems. + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasm_bpf) + else() + message(WARNING "Only Linux platforms support wasm_bpf plug-in now.") + endif() +endif() + +# WasmEdge plug-in: ffmpeg. +if(WASMEDGE_PLUGIN_FFMPEG) + add_subdirectory(wasmedge_ffmpeg) +endif() + +# WasmEdge plug-in: Image. +if(WASMEDGE_PLUGIN_IMAGE) + # wasmedge_image is currently supported only on Linux and macOS. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_image) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_Image plug-in now.") + endif() +endif() + +# WasmEdge plug-in: LLMC. +if(WASMEDGE_PLUGIN_LLMC) + add_subdirectory(wasmedge_llmc) +endif() + +# WasmEdge plug-in: OpenCV-mini. +if(WASMEDGE_PLUGIN_OPENCVMINI) + # wasmedge_opencvmini is currently supported only on Linux and macOS. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_opencvmini) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_OpenCVMini plug-in now.") + endif() +endif() + +# WasmEdge plug-in: Process. +if(WASMEDGE_PLUGIN_PROCESS) + # wasmedge_process is currently supported only on Linux systems. + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasmedge_process) + else() + message(WARNING "Only Linux platforms support WasmEdge_Process plug-in now.") + endif() +endif() + +# WasmEdge plug-in: Stable-diffusion. +if(WASMEDGE_PLUGIN_STABLEDIFFUSION) + # wasmedge_stablediffusion is currently supported only on Linux and macOS. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_stablediffusion) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_StableDiffusion plug-in now.") + endif() +endif() + +# WasmEdge plug-in: TensorFlow. +if(WASMEDGE_PLUGIN_TENSORFLOW) + # wasmedge_tensorflow is currently supported only on Linux and macOS. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflow) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_Tensorflow plug-in now.") + endif() +endif() + +# WasmEdge plug-in: TensorFlow-Lite. +if(WASMEDGE_PLUGIN_TENSORFLOWLITE) + # wasmedge_tensorflowlite is currently supported only on Linux and macOS. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflowlite) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_TensorflowLite plug-in now.") + endif() +endif() + +# WasmEdge plug-in: zlib. +if(WASMEDGE_PLUGIN_ZLIB) + add_subdirectory(wasmedge_zlib) +endif() + +# Plug-in unit tests. +add_subdirectory(unittest) diff --git a/test/plugins/unittest/CMakeLists.txt b/test/plugins/unittest/CMakeLists.txt new file mode 100644 index 00000000..8a0e2ad5 --- /dev/null +++ b/test/plugins/unittest/CMakeLists.txt @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +# The test plugin module in C API +enable_language(C) + +wasmedge_add_library(wasmedgePluginTestModuleC + SHARED + testplugin.c + ) + +set_target_properties(wasmedgePluginTestModuleC PROPERTIES + C_STANDARD 11 +) + +# remove cxx_standard for msvc +set_property(TARGET wasmedgePluginTestModuleC PROPERTY + CXX_STANDARD +) + +target_compile_options(wasmedgePluginTestModuleC + PUBLIC + -DWASMEDGE_PLUGIN +) + +# The test plugin module in C++ API +wasmedge_add_library(wasmedgePluginTestModuleCPP + SHARED + testplugin.cpp +) + +target_compile_options(wasmedgePluginTestModuleCPP + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginTestModuleCPP + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +# The test executable for C API +wasmedge_add_executable(wasmedgePluginUnittestsC + unittest_c.cpp +) + +target_include_directories(wasmedgePluginUnittestsC + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + $ +) + +target_link_libraries(wasmedgePluginUnittestsC + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) + +# The test executable for C++ API +wasmedge_add_executable(wasmedgePluginUnittestsCPP + unittest_cpp.cpp +) + +target_include_directories(wasmedgePluginUnittestsCPP + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + $ +) + +target_link_libraries(wasmedgePluginUnittestsCPP + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) + +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginTestModuleC + PRIVATE + wasmedgeCAPI + ) + target_link_libraries(wasmedgePluginTestModuleCPP + PRIVATE + wasmedgeCAPI + ) + target_link_libraries(wasmedgePluginUnittestsC + PRIVATE + wasmedgeCAPI + ) + target_link_libraries(wasmedgePluginUnittestsCPP + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginTestModuleC + PRIVATE + wasmedge_shared + ) + target_link_libraries(wasmedgePluginTestModuleCPP + PRIVATE + wasmedge_shared + ) + target_link_libraries(wasmedgePluginUnittestsC + PRIVATE + wasmedge_shared + ) + target_link_libraries(wasmedgePluginUnittestsCPP + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgePluginUnittestsC wasmedgePluginUnittestsC) +add_test(wasmedgePluginUnittestsCPP wasmedgePluginUnittestsCPP) diff --git a/test/plugins/unittest/testplugin.c b/test/plugins/unittest/testplugin.c new file mode 100644 index 00000000..f9dac55f --- /dev/null +++ b/test/plugins/unittest/testplugin.c @@ -0,0 +1,123 @@ + +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasmedge/wasmedge.h" + +#include +#include +#include + +static WasmEdge_String NameString; +static const char NameCString[] = "name"; +static const WasmEdge_String NameStringDefaultValue = {.Buf = NameCString, + .Length = 4}; +static void Finalizer(void *Data) { + printf("Deallocate host data\n"); + free((int32_t *)Data); +} + +static WasmEdge_Result +HostFuncAdd(void *Data, const WasmEdge_CallingFrameContext *CallFrameCxt, + const WasmEdge_Value *In, WasmEdge_Value *Out) { + (void)CallFrameCxt; + /* + * Host function to calculate A + B, + * and accumulate (A + B) to the host data. + */ + int32_t Val1 = WasmEdge_ValueGetI32(In[0]); + int32_t Val2 = WasmEdge_ValueGetI32(In[1]); + Out[0] = WasmEdge_ValueGenI32(Val1 + Val2); + int32_t *Accum = (int32_t *)Data; + *Accum += WasmEdge_ValueGetI32(Out[0]); + printf("Current accumulate: %d\n", *Accum); + return WasmEdge_Result_Success; +} + +static WasmEdge_Result +HostFuncSub(void *Data, const WasmEdge_CallingFrameContext *CallFrameCxt, + const WasmEdge_Value *In, WasmEdge_Value *Out) { + (void)CallFrameCxt; + /* + * Host function to calculate A - B, + * and accumulate (A - B) to the host data. + */ + int32_t Val1 = WasmEdge_ValueGetI32(In[0]); + int32_t Val2 = WasmEdge_ValueGetI32(In[1]); + Out[0] = WasmEdge_ValueGenI32(Val1 - Val2); + int32_t *Accum = (int32_t *)Data; + *Accum += WasmEdge_ValueGetI32(Out[0]); + printf("Current accumulate: %d\n", *Accum); + return WasmEdge_Result_Success; +} + +static WasmEdge_ModuleInstanceContext * +CreateTestModule(const struct WasmEdge_ModuleDescriptor *Desc) { + /* Allocate and initialize a host data. */ + printf("Allocate host data\n"); + int32_t *Accumulate = (int32_t *)malloc(sizeof(int32_t)); + *Accumulate = 0; + + /* Create the module instance. */ + WasmEdge_String ModuleName = WasmEdge_StringCreateByCString(Desc->Name); + WasmEdge_ModuleInstanceContext *Mod = + WasmEdge_ModuleInstanceCreateWithData(ModuleName, Accumulate, Finalizer); + WasmEdge_StringDelete(ModuleName); + + WasmEdge_String FuncName; + WasmEdge_FunctionTypeContext *FType; + WasmEdge_FunctionInstanceContext *FuncCxt; + WasmEdge_ValType ParamTypes[2], ReturnTypes[1]; + ParamTypes[0] = WasmEdge_ValTypeGenI32(); + ParamTypes[1] = WasmEdge_ValTypeGenI32(); + ReturnTypes[0] = WasmEdge_ValTypeGenI32(); + + /* Create the "add" function and add it to the module instance. */ + FType = WasmEdge_FunctionTypeCreate(ParamTypes, 2, ReturnTypes, 1); + FuncName = WasmEdge_StringCreateByCString("add"); + FuncCxt = WasmEdge_FunctionInstanceCreate(FType, HostFuncAdd, Accumulate, 0); + WasmEdge_ModuleInstanceAddFunction(Mod, FuncName, FuncCxt); + WasmEdge_StringDelete(FuncName); + /* Create the "sub" function and add it to the module instance. */ + FuncName = WasmEdge_StringCreateByCString("sub"); + FuncCxt = WasmEdge_FunctionInstanceCreate(FType, HostFuncSub, Accumulate, 0); + WasmEdge_ModuleInstanceAddFunction(Mod, FuncName, FuncCxt); + WasmEdge_StringDelete(FuncName); + WasmEdge_FunctionTypeDelete(FType); + + return Mod; +} + +static WasmEdge_ProgramOption PODesc[] = {{ + .Name = "name", + .Description = "test name string", + .Type = WasmEdge_ProgramOptionType_String, + .Storage = &NameString, + .DefaultValue = &NameStringDefaultValue, +}}; +static WasmEdge_ModuleDescriptor ModuleDesc[] = {{ + .Name = "wasmedge_plugintest_c_module", + .Description = "This is for the plugin tests in WasmEdge C API.", + .Create = CreateTestModule, +}}; +static WasmEdge_PluginDescriptor Desc = { + .Name = "wasmedge_plugintest_c", + .Description = "", + .APIVersion = WasmEdge_Plugin_CurrentAPIVersion, + .Version = + { + .Major = 0, + .Minor = 10, + .Patch = 0, + .Build = 0, + }, + .ModuleCount = 1, + .ProgramOptionCount = 1, + .ModuleDescriptions = ModuleDesc, + .ProgramOptions = PODesc, +}; + +WASMEDGE_CAPI_PLUGIN_EXPORT const WasmEdge_PluginDescriptor * +WasmEdge_Plugin_GetDescriptor(void) { + return &Desc; +} diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp new file mode 100644 index 00000000..95731844 --- /dev/null +++ b/test/plugins/unittest/testplugin.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "testplugin.h" +#include "po/helper.h" + +#include + +namespace WasmEdge { +namespace Host { + +using namespace std::literals::string_view_literals; + +PO::List + WasmEdgePluginTestEnv::CmdArgs(PO::Description("Test for args."sv), + PO::MetaVar("ARG"sv)); + +PO::Option + WasmEdgePluginTestEnv::CmdName(PO::Description("Test for input name."sv), + PO::DefaultValue(std::string(""))); + +PO::Option + WasmEdgePluginTestEnv::CmdOpt(PO::Description("Test for option."sv)); + +namespace { + +void addOptions(const Plugin::Plugin::PluginDescriptor *, + PO::ArgumentParser &Parser) noexcept { + Parser.add_option("arg"sv, WasmEdgePluginTestEnv::CmdArgs) + .add_option("name"sv, WasmEdgePluginTestEnv::CmdName) + .add_option("opt"sv, WasmEdgePluginTestEnv::CmdOpt); +} + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgePluginTestModule; +} + +static Plugin::PluginModule::ModuleDescriptor MD[]{ + { + /* Name */ "wasmedge_plugintest_cpp_module", + /* Description */ "This is for the plugin tests in WasmEdge.", + /* Create */ create, + }, +}; + +const Plugin::Plugin::PluginDescriptor Descriptor{ + /* Name */ "wasmedge_plugintest_cpp", + /* Description */ "", + /* APIVersion */ Plugin::Plugin::CurrentAPIVersion, + /* Version */ {0, 10, 0, 0}, + /* ModuleCount */ 1, + /* ModuleDescriptions */ MD, + /* ComponentCount */ 0, + /* ComponentDescriptions */ nullptr, + /* AddOptions */ addOptions, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/unittest/testplugin.h b/test/plugins/unittest/testplugin.h new file mode 100644 index 00000000..82376b3e --- /dev/null +++ b/test/plugins/unittest/testplugin.h @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "po/argument_parser.h" +#include "po/list.h" +#include "po/option.h" + +#include +#include + +namespace WasmEdge { +namespace Host { + +class WasmEdgePluginTestEnv { +public: + WasmEdgePluginTestEnv() noexcept = default; + + static PO::List CmdArgs; + static PO::Option CmdName; + static PO::Option CmdOpt; +}; + +template +class WasmEdgePluginTestFunc : public Runtime::HostFunction { +public: + WasmEdgePluginTestFunc(WasmEdgePluginTestEnv &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasmEdgePluginTestEnv &Env; +}; + +class WasmEdgePluginTestFuncAdd + : public WasmEdgePluginTestFunc { +public: + WasmEdgePluginTestFuncAdd(WasmEdgePluginTestEnv &HostEnv) + : WasmEdgePluginTestFunc(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t A, uint32_t B) { + return A + B; + } +}; + +class WasmEdgePluginTestFuncSub + : public WasmEdgePluginTestFunc { +public: + WasmEdgePluginTestFuncSub(WasmEdgePluginTestEnv &HostEnv) + : WasmEdgePluginTestFunc(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t A, uint32_t B) { + return A - B; + } +}; + +class WasmEdgePluginTestFuncArgLen + : public WasmEdgePluginTestFunc { +public: + WasmEdgePluginTestFuncArgLen(WasmEdgePluginTestEnv &HostEnv) + : WasmEdgePluginTestFunc(HostEnv) {} + Expect body(const Runtime::CallingFrame &) { + return static_cast(Env.CmdArgs.value().size()); + } +}; + +class WasmEdgePluginTestFuncOpt + : public WasmEdgePluginTestFunc { +public: + WasmEdgePluginTestFuncOpt(WasmEdgePluginTestEnv &HostEnv) + : WasmEdgePluginTestFunc(HostEnv) {} + Expect body(const Runtime::CallingFrame &) { + return static_cast(Env.CmdOpt.value()); + } +}; + +class WasmEdgePluginTestFuncNameSize + : public WasmEdgePluginTestFunc { +public: + WasmEdgePluginTestFuncNameSize(WasmEdgePluginTestEnv &HostEnv) + : WasmEdgePluginTestFunc(HostEnv) {} + Expect body(const Runtime::CallingFrame &) { + return static_cast(Env.CmdName.value().size()); + } +}; + +class WasmEdgePluginTestModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgePluginTestModule() + : Runtime::Instance::ModuleInstance("wasmedge_plugintest_cpp_module") { + addHostFunc("add", std::make_unique(Env)); + addHostFunc("sub", std::make_unique(Env)); + addHostFunc("arg_len", std::make_unique(Env)); + addHostFunc("opt", std::make_unique(Env)); + addHostFunc("name_size", + std::make_unique(Env)); + } + + WasmEdgePluginTestEnv &getEnv() { return Env; } + +private: + WasmEdgePluginTestEnv Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/unittest/unittest_c.cpp b/test/plugins/unittest/unittest_c.cpp new file mode 100644 index 00000000..1943648b --- /dev/null +++ b/test/plugins/unittest/unittest_c.cpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/defines.h" +#include "wasmedge/wasmedge.h" + +#include + +#include + +namespace { +WasmEdge_ModuleInstanceContext *createModuleC() { + WasmEdge_PluginLoadFromPath( + "./" WASMEDGE_LIB_PREFIX + "wasmedgePluginTestModuleC" WASMEDGE_LIB_EXTENSION); + WasmEdge_String Str = WasmEdge_StringCreateByCString("wasmedge_plugintest_c"); + const WasmEdge_PluginContext *PluginCxt = WasmEdge_PluginFind(Str); + WasmEdge_StringDelete(Str); + if (!PluginCxt) { + return nullptr; + } + + Str = WasmEdge_StringCreateByCString("wasmedge_plugintest_c_module"); + WasmEdge_ModuleInstanceContext *ModCxt = + WasmEdge_PluginCreateModule(PluginCxt, Str); + WasmEdge_StringDelete(Str); + return ModCxt; +} + +WasmEdge_ModuleInstanceContext *createModuleCPP() { + WasmEdge_PluginLoadFromPath( + "./" WASMEDGE_LIB_PREFIX + "wasmedgePluginTestModuleCPP" WASMEDGE_LIB_EXTENSION); + WasmEdge_String Str = + WasmEdge_StringCreateByCString("wasmedge_plugintest_cpp"); + const WasmEdge_PluginContext *PluginCxt = WasmEdge_PluginFind(Str); + WasmEdge_StringDelete(Str); + if (!PluginCxt) { + return nullptr; + } + + Str = WasmEdge_StringCreateByCString("wasmedge_plugintest_cpp_module"); + WasmEdge_ModuleInstanceContext *ModCxt = + WasmEdge_PluginCreateModule(PluginCxt, Str); + WasmEdge_StringDelete(Str); + return ModCxt; +} +} // namespace + +TEST(wasmedgePluginTests, C_Run) { + auto *ExecCxt = WasmEdge_ExecutorCreate(nullptr, nullptr); + auto *StoreCxt = WasmEdge_StoreCreate(); + WasmEdge_Result Res; + WasmEdge_FunctionInstanceContext *FuncCxt; + WasmEdge_String FuncName; + WasmEdge_Value Params[2], Returns[1]; + + // Create the wasmedge_plugintest_c_module module instance. + auto *ModInstC = createModuleC(); + EXPECT_FALSE(ModInstC == nullptr); + int32_t *HostData = + static_cast(WasmEdge_ModuleInstanceGetHostData(ModInstC)); + EXPECT_FALSE(HostData == nullptr); + + // Create the wasmedge_plugintest_cpp_module module instance. + auto *ModInstCPP = createModuleCPP(); + EXPECT_FALSE(ModInstCPP == nullptr); + + // Test: Run the function "wasmedge_plugintest_c.add". + FuncName = WasmEdge_StringCreateByCString("add"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstC, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Params[0] = WasmEdge_ValueGenI32(111); + Params[1] = WasmEdge_ValueGenI32(333); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, Params, 2, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 444); + EXPECT_EQ(*HostData, 444); + + // Test: Run the function "wasmedge_plugintest_c.sub". + FuncName = WasmEdge_StringCreateByCString("sub"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstC, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Params[0] = WasmEdge_ValueGenI32(666); + Params[1] = WasmEdge_ValueGenI32(555); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, Params, 2, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 111); + EXPECT_EQ(*HostData, 555); + + // Test: Run the function "wasmedge_plugintest_cpp.add". + FuncName = WasmEdge_StringCreateByCString("add"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstCPP, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Params[0] = WasmEdge_ValueGenI32(111); + Params[1] = WasmEdge_ValueGenI32(333); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, Params, 2, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 444); + + // Test: Run the function "wasmedge_plugintest_cpp.sub". + FuncName = WasmEdge_StringCreateByCString("sub"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstCPP, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Params[0] = WasmEdge_ValueGenI32(666); + Params[1] = WasmEdge_ValueGenI32(555); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, Params, 2, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 111); + + // Test: Run the function "wasmedge_plugintest_cpp.arg_len". + FuncName = WasmEdge_StringCreateByCString("arg_len"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstCPP, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, nullptr, 0, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 0); + + // Test: Run the function "wasmedge_plugintest_cpp.name_size". + FuncName = WasmEdge_StringCreateByCString("name_size"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstCPP, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, nullptr, 0, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 0); + + WasmEdge_ExecutorDelete(ExecCxt); + WasmEdge_StoreDelete(StoreCxt); + WasmEdge_ModuleInstanceDelete(ModInstC); + WasmEdge_ModuleInstanceDelete(ModInstCPP); +} + +TEST(wasmedgePluginTests, C_Module) { + WasmEdge_String NameBuf[16]; + + // Create the wasmedge_plugintest_c_module module instance. + auto *ModInstC = createModuleC(); + ASSERT_FALSE(ModInstC == nullptr); + EXPECT_EQ(WasmEdge_ModuleInstanceListFunctionLength(ModInstC), 2U); + std::memset(NameBuf, 0, sizeof(WasmEdge_String) * 16); + EXPECT_EQ(WasmEdge_ModuleInstanceListFunction(ModInstC, NameBuf, 16), 2U); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[0], WasmEdge_StringWrap("add", 3U))); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[1], WasmEdge_StringWrap("sub", 3U))); + WasmEdge_ModuleInstanceDelete(ModInstC); + + // Create the wasmedge_plugintest_cpp_module module instance. + auto *ModInstCPP = createModuleCPP(); + ASSERT_FALSE(ModInstCPP == nullptr); + EXPECT_EQ(WasmEdge_ModuleInstanceListFunctionLength(ModInstCPP), 5U); + std::memset(NameBuf, 0, sizeof(WasmEdge_String) * 16); + EXPECT_EQ(WasmEdge_ModuleInstanceListFunction(ModInstCPP, NameBuf, 16), 5U); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[0], WasmEdge_StringWrap("add", 3U))); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[1], WasmEdge_StringWrap("arg_len", 7U))); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[2], WasmEdge_StringWrap("name_size", 9U))); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[3], WasmEdge_StringWrap("opt", 3U))); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[4], WasmEdge_StringWrap("sub", 3U))); + WasmEdge_ModuleInstanceDelete(ModInstCPP); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/unittest/unittest_cpp.cpp b/test/plugins/unittest/unittest_cpp.cpp new file mode 100644 index 00000000..f800c29e --- /dev/null +++ b/test/plugins/unittest/unittest_cpp.cpp @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/defines.h" +#include "plugin/plugin.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +std::unique_ptr createModuleC() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "./" WASMEDGE_LIB_PREFIX + "wasmedgePluginTestModuleC" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_plugintest_c"sv)) { + if (const auto *Module = + Plugin->findModule("wasmedge_plugintest_c_module"sv)) { + return Module->create(); + } + } + return {}; +} + +std::unique_ptr createModuleCPP() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "./" WASMEDGE_LIB_PREFIX + "wasmedgePluginTestModuleCPP" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_plugintest_cpp"sv)) { + WasmEdge::PO::ArgumentParser Parser; + Plugin->registerOptions(Parser); + Parser.set_raw_value("name"sv, std::string("test_name")); + Parser.set_raw_value>( + "arg"sv, std::vector({"arg0", "arg1", "arg2", "arg3"})); + Parser.set_raw_value("opt"sv); + if (const auto *Module = + Plugin->findModule("wasmedge_plugintest_cpp_module"sv)) { + return Module->create(); + } + } + return {}; +} +} // namespace + +TEST(wasmedgePluginTests, CPP_Run) { + // Create the wasmedge_plugintest_cpp_module module instance. + auto TestModCPP = createModuleCPP(); + ASSERT_TRUE(TestModCPP); + + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + std::array RetVal; + + // Get the function "arg_len". + auto *FuncInst1 = TestModCPP->findFuncExports("arg_len"); + EXPECT_NE(FuncInst1, nullptr); + EXPECT_TRUE(FuncInst1->isHostFunction()); + auto &HostFuncInst1 = FuncInst1->getHostFunc(); + + // Test: Run function successfully. + EXPECT_TRUE(HostFuncInst1.run(CallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 4); + + // Get the function "name_size". + auto *FuncInst2 = TestModCPP->findFuncExports("name_size"); + EXPECT_NE(FuncInst2, nullptr); + EXPECT_TRUE(FuncInst2->isHostFunction()); + auto &HostFuncInst2 = FuncInst2->getHostFunc(); + + // Test: Run function successfully. + EXPECT_TRUE(HostFuncInst2.run(CallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 9); + + // Get the function "opt". + auto *FuncInst3 = TestModCPP->findFuncExports("opt"); + EXPECT_NE(FuncInst3, nullptr); + EXPECT_TRUE(FuncInst3->isHostFunction()); + auto &HostFuncInst3 = FuncInst3->getHostFunc(); + + // Test: Run function successfully. + EXPECT_TRUE(HostFuncInst3.run(CallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 1); + + // Create the wasmedge_plugintest_c_module module instance. + auto TestModC = createModuleC(); + ASSERT_TRUE(TestModC); + // The host functions are implemented in the C API. + // Therefore not test to invoke them here. +} + +TEST(wasmedgePluginTests, CPP_Module) { + // Create the wasmedge_plugintest_cpp_module module instance. + auto TestModCPP = createModuleCPP(); + ASSERT_TRUE(TestModCPP); + EXPECT_EQ(TestModCPP->getFuncExportNum(), 5U); + EXPECT_NE(TestModCPP->findFuncExports("add"), nullptr); + EXPECT_NE(TestModCPP->findFuncExports("sub"), nullptr); + EXPECT_NE(TestModCPP->findFuncExports("arg_len"), nullptr); + EXPECT_NE(TestModCPP->findFuncExports("opt"), nullptr); + EXPECT_NE(TestModCPP->findFuncExports("name_size"), nullptr); + + // Create the wasmedge_plugintest_c_module module instance. + auto TestModC = createModuleC(); + ASSERT_TRUE(TestModC); + EXPECT_EQ(TestModC->getFuncExportNum(), 2U); + EXPECT_NE(TestModC->findFuncExports("add"), nullptr); + EXPECT_NE(TestModC->findFuncExports("sub"), nullptr); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasi_crypto/CMakeLists.txt b/test/plugins/wasi_crypto/CMakeLists.txt new file mode 100644 index 00000000..69fc8482 --- /dev/null +++ b/test/plugins/wasi_crypto/CMakeLists.txt @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasiCryptoTests + aeads.cpp + asymmetric.cpp + common.cpp + hash.cpp + helper.cpp + kdf.cpp + kx.cpp + mac.cpp + notimplement.cpp + signatures.cpp +) + +add_dependencies(wasiCryptoTests + wasmedgePluginWasiCrypto +) + +target_compile_options(wasiCryptoTests + PUBLIC + -DOPENSSL_API_COMPAT=0x10100000L +) + +target_include_directories(wasiCryptoTests + PUBLIC + $ + $ +) + +target_link_libraries(wasiCryptoTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasiCryptoTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasiCryptoTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasiCryptoTests wasiCryptoTests) diff --git a/test/plugins/wasi_crypto/aeads.cpp b/test/plugins/wasi_crypto/aeads.cpp new file mode 100644 index 00000000..bae54a26 --- /dev/null +++ b/test/plugins/wasi_crypto/aeads.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Aeads) { + auto AeadsTest = [this](std::string_view Name, + const std::vector &Nonce, size_t MaxTagSize, + const std::vector &Msg) { + SCOPED_TRACE(Name); + + WASI_CRYPTO_EXPECT_SUCCESS(KeyHandle, + symmetricKeyGenerate(Name, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS(OptionsHandle, + optionsOpen(__WASI_ALGORITHM_TYPE_SYMMETRIC)); + // Repeatedly set and overwrite the previous option. + WASI_CRYPTO_EXPECT_TRUE(optionsSet(OptionsHandle, "nonce"sv, "nonce"_u8)); + WASI_CRYPTO_EXPECT_TRUE(optionsSet(OptionsHandle, "nonce"sv, Nonce)); + WASI_CRYPTO_EXPECT_SUCCESS( + State1Handle, symmetricStateOpen(Name, KeyHandle, OptionsHandle)); + + // State nonce equal to the previous set one. + std::vector ObservedNonce(Nonce.size()); + symmetricStateOptionsGet(State1Handle, "nonce"sv, ObservedNonce); + EXPECT_EQ(ObservedNonce, Nonce); + WASI_CRYPTO_EXPECT_SUCCESS(TagSize, symmetricStateMaxTagLen(State1Handle)); + EXPECT_EQ(TagSize, MaxTagSize); + + std::vector CiphertextWithTag(Msg.size() + MaxTagSize); + + WASI_CRYPTO_EXPECT_SUCCESS( + OutTagSize, + symmetricStateEncrypt(State1Handle, CiphertextWithTag, Msg)); + EXPECT_EQ(OutTagSize, CiphertextWithTag.size()); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(State1Handle)); + + { + WASI_CRYPTO_EXPECT_SUCCESS( + State2Handle, symmetricStateOpen(Name, KeyHandle, OptionsHandle)); + std::vector Msg2(Msg.size()); + WASI_CRYPTO_EXPECT_SUCCESS( + OutputMsg2Size, + symmetricStateDecrypt(State2Handle, Msg2, CiphertextWithTag)); + EXPECT_EQ(OutputMsg2Size, Msg2.size()); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(State2Handle)); + EXPECT_EQ(Msg2, Msg); + } + + WASI_CRYPTO_EXPECT_SUCCESS( + State3Handle, symmetricStateOpen(Name, KeyHandle, OptionsHandle)); + std::vector Ciphertext(Msg.size()); + WASI_CRYPTO_EXPECT_SUCCESS(TagHandle, symmetricStateEncryptDetached( + State3Handle, Ciphertext, Msg)); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(State3Handle)); + std::vector Tag(MaxTagSize); + WASI_CRYPTO_EXPECT_SUCCESS(OutputTagSize, symmetricTagPull(TagHandle, Tag)); + EXPECT_EQ(OutputTagSize, MaxTagSize); + EXPECT_EQ(Tag, + std::vector( + CiphertextWithTag.begin() + + static_cast( + Msg.size()), + CiphertextWithTag.end())); + + WASI_CRYPTO_EXPECT_SUCCESS( + State4Handle, symmetricStateOpen(Name, KeyHandle, OptionsHandle)); + std::vector Msg3(Msg.size()); + symmetricStateDecryptDetached(State4Handle, Msg3, Ciphertext, Tag); + EXPECT_EQ("test"_u8, Msg3); + WASI_CRYPTO_EXPECT_TRUE(optionsClose(OptionsHandle)); + + { + // Some error cases checking. + EXPECT_TRUE( + symmetricStateOpen(Name, InvaildHandle, std::nullopt).error() == + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGet(State4Handle, "foo"sv, {}), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGetU64(State4Handle, "foo"sv), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeTag(State4Handle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeKey(State4Handle, Name), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueeze(State4Handle, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateRatchet(State4Handle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + bool IsSymmetricStateCloneImplemented = true; + // XXX: These cipher didn't implement context duplication from OpenSSL 3.0.0 + // https://github.com/openssl/openssl/issues/20978 + if (0x30000000 <= OPENSSL_VERSION_NUMBER && + (Name == "AES-128-GCM"sv || Name == "AES-256-GCM"sv || + Name == "CHACHA20-POLY1305"sv)) { + IsSymmetricStateCloneImplemented = false; + } + + if (IsSymmetricStateCloneImplemented) { + // Clone checking. + WASI_CRYPTO_EXPECT_SUCCESS(NewStateHandle, + symmetricStateClone(State4Handle)); + EXPECT_NE(State4Handle, NewStateHandle); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(NewStateHandle)); + } else { + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateClone(State4Handle), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + } + + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(State4Handle)); + }; + + AeadsTest("AES-128-GCM"sv, std::vector(12, 42), 16, "test"_u8); + AeadsTest("AES-256-GCM"sv, std::vector(12, 42), 16, "test"_u8); + AeadsTest("CHACHA20-POLY1305"sv, std::vector(12, 42), 16, "test"_u8); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/asymmetric.cpp b/test/plugins/wasi_crypto/asymmetric.cpp new file mode 100644 index 00000000..2a2b5018 --- /dev/null +++ b/test/plugins/wasi_crypto/asymmetric.cpp @@ -0,0 +1,642 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Asymmetric) { + auto EncodingCheck = + [this](std::string_view Alg, __wasi_algorithm_type_e_t AlgType, + std::map<__wasi_publickey_encoding_e_t, std::vector> + SupportPk, + std::map<__wasi_secretkey_encoding_e_t, std::vector> + SupportSk, + std::map<__wasi_keypair_encoding_e_t, std::vector> + SupportKp) { + SCOPED_TRACE(Alg); + + // Function checking. + { + WASI_CRYPTO_EXPECT_SUCCESS( + KpHandle, keypairGenerate(AlgType, Alg, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS(PkHandle, keypairPublickey(KpHandle)); + WASI_CRYPTO_EXPECT_SUCCESS(SkHandle, keypairSecretkey(KpHandle)); + WASI_CRYPTO_EXPECT_TRUE(keypairClose(KpHandle)); + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(PkHandle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(SkHandle)); + + if (Alg == "Ed25519"sv) { + WASI_CRYPTO_EXPECT_SUCCESS( + KpMHandle, + keypairGenerateManaged(1, AlgType, Alg, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS(PkMHandle, keypairPublickey(KpMHandle)); + WASI_CRYPTO_EXPECT_SUCCESS(SkMHandle, keypairSecretkey(KpMHandle)); + WASI_CRYPTO_EXPECT_TRUE(keypairClose(KpMHandle)); + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(PkMHandle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(SkMHandle)); + } else { + WASI_CRYPTO_EXPECT_FAILURE( + keypairGenerateManaged(1, AlgType, Alg, std::nullopt), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + } + } + + // Encoding checking. + for (auto &&[PkEncoding, Pk] : SupportPk) { + SCOPED_TRACE("Public key encoding"); + SCOPED_TRACE(PkEncoding); + WASI_CRYPTO_EXPECT_SUCCESS( + PkHandle, publickeyImport(AlgType, Alg, Pk, PkEncoding)); + + std::vector ExportPk(Pk.size()); + WASI_CRYPTO_EXPECT_SUCCESS(PkOutputHandle, + publickeyExport(PkHandle, PkEncoding)); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(PkOutputHandle, ExportPk)); + EXPECT_EQ(ExportPk, Pk); + } + for (auto &&[SkEncoding, Sk] : SupportSk) { + SCOPED_TRACE("Secret key encoding"); + SCOPED_TRACE(SkEncoding); + WASI_CRYPTO_EXPECT_SUCCESS( + SkHandle, secretkeyImport(AlgType, Alg, Sk, SkEncoding)); + + std::vector ExportSk(Sk.size()); + WASI_CRYPTO_EXPECT_SUCCESS(SkOutputHandle, + secretkeyExport(SkHandle, SkEncoding)); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(SkOutputHandle, ExportSk)); + EXPECT_EQ(ExportSk, Sk); + } + for (auto &&[KpEncoding, Kp] : SupportKp) { + SCOPED_TRACE("Key Pair encoding"); + SCOPED_TRACE(KpEncoding); + WASI_CRYPTO_EXPECT_SUCCESS( + KpHandle, keypairImport(AlgType, Alg, Kp, KpEncoding)); + + std::vector ExportKp(Kp.size()); + WASI_CRYPTO_EXPECT_SUCCESS(KpOutputHandle, + keypairExport(KpHandle, KpEncoding)); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(KpOutputHandle, ExportKp)); + EXPECT_EQ(ExportKp, Kp); + } + }; + EncodingCheck( + "X25519"sv, __WASI_ALGORITHM_TYPE_KEY_EXCHANGE, + {{__WASI_PUBLICKEY_ENCODING_RAW, + "8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_RAW, + "77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a"_u8v}}, + {{__WASI_KEYPAIR_ENCODING_RAW, + "8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a"_u8v}}); + EncodingCheck( + "ECDSA_P256_SHA256"sv, __WASI_ALGORITHM_TYPE_SIGNATURES, + {{__WASI_PUBLICKEY_ENCODING_SEC, + "0460FED4BA255A9D31C961EB74C6356D68C049B8923B61FA6CE669622E60F29FB67903FE1008B8BC99A41AE9E95628BC64F2F1B20C2D7E9F5177A3C294D4462299"_u8v}, + {__WASI_PUBLICKEY_ENCODING_PKCS8, + "3059301306072a8648ce3d020106082a8648ce3d0301070342000460FED4BA255A9D31C961EB74C6356D68C049B8923B61FA6CE669622E60F29FB67903FE1008B8BC99A41AE9E95628BC64F2F1B20C2D7E9F5177A3C294D4462299"_u8v}, + {__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEYP7UuiVanTHJYet0xjVtaMBJuJI7\n" + "Yfps5mliLmDyn7Z5A/4QCLi8maQa6elWKLxk8vGyDC1+n1F3o8KU1EYimQ==\n" + "-----END PUBLIC KEY-----\n"_u8}}, + {{__WASI_SECRETKEY_ENCODING_RAW, + "C9AFA9D845BA75166B5C215767B1D6934E50C3DB36E89B127B8A622B120F6721"_u8v}, + {__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgya+p2EW6dRZrXCFX\n" + "Z7HWk05Qw9s26JsSe4piKxIPZyGhRANCAARg/tS6JVqdMclh63TGNW1owEm4kjth\n" + "+mzmaWIuYPKftnkD/hAIuLyZpBrp6VYovGTy8bIMLX6fUXejwpTURiKZ\n" + "-----END PRIVATE KEY-----\n"_u8}}, + {}); + EncodingCheck( + "ECDSA_K256_SHA256"sv, __WASI_ALGORITHM_TYPE_SIGNATURES, + {{__WASI_PUBLICKEY_ENCODING_SEC, + "04b838ff44e5bc177bf21189d0766082fc9d843226887fc9760371100b7ee20a6ff0c9d75bfba7b31a6bca1974496eeb56de357071955d83c4b1badaa0b21832e9"_u8v}, + {__WASI_PUBLICKEY_ENCODING_PKCS8, + "3056301006072a8648ce3d020106052b8104000a03420004b838ff44e5bc177bf21189d0766082fc9d843226887fc9760371100b7ee20a6ff0c9d75bfba7b31a6bca1974496eeb56de357071955d83c4b1badaa0b21832e9"_u8v}, + {__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MFYwEAYHKoZIzj0CAQYFK4EEAAoDQgAEuDj/ROW8F3vyEYnQdmCC/J2EMiaIf8l2\n" + "A3EQC37iCm/wyddb+6ezGmvKGXRJbutW3jVwcZVdg8Sxutqgshgy6Q==\n" + "-----END PUBLIC KEY-----"_u8}}, + {{__WASI_SECRETKEY_ENCODING_RAW, + "b9aa5c28ef96d750e47f4ba44d5d6a7ac3ab6988d292e7819e362a4b0ac8e250"_u8v}, + {__WASI_SECRETKEY_ENCODING_PKCS8, + "308184020100301006072a8648ce3d020106052b8104000a046d306b02010104" + "207778b8225c02cc7f2ebcd0a47e2c4fcebd6716a329bdf2e4f961fa35041cba" + "97a1440342000434e2dea3923666bc28779bcd84fba5b4ee97bb8f6ec3cdc0d8" + "6609f6c8b8b9ca81592cdf4d3aeccdacb092e94e8f814265f46e3eefb49ad43c" + "3968e69d4faef4"_u8v}, + {__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQguapcKO+W11Dkf0ukTV1q\n" + "esOraYjSkueBnjYqSwrI4lChRANCAAR/744haGNwx9NDmS8UstRaJizWpcdQMnNv\n" + "y7AvRqme3w4dEUzck5Vsx1ZIv9OPqDKoITXVwrpjR2aodT9tiKrl\n" + "-----END PRIVATE KEY-----\n"_u8}}, + {}); + EncodingCheck( + "ECDSA_P384_SHA384"sv, __WASI_ALGORITHM_TYPE_SIGNATURES, + {{__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEUIwwEUeBvHWigFKAYgPs93Pb21mUr2Oa\n" + "CTSiuK0Hmjc25OO5m4Gk92s4HG22R5596FB7nXlZljqEnETE4xDc4Dugv5ZzlTCf\n" + "HLeAvmfv+hFMRzroZ+GS1Xwgl434yH9r\n" + "-----END PUBLIC KEY-----\n"_u8}}, + {{__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDAHfLyuXwp7DoNPIvxg\n" + "B5k8zOAyXHFpJ4FF7CIg4zH/UBFb5m8AyT+c9rvvyVcHlEKhZANiAARQjDARR4G8\n" + "daKAUoBiA+z3c9vbWZSvY5oJNKK4rQeaNzbk47mbgaT3azgcbbZHnn3oUHudeVmW\n" + "OoScRMTjENzgO6C/lnOVMJ8ct4C+Z+/6EUxHOuhn4ZLVfCCXjfjIf2s=\n" + "-----END PRIVATE KEY-----\n"_u8}}, + {}); + EncodingCheck( + "Ed25519"sv, __WASI_ALGORITHM_TYPE_SIGNATURES, + {{__WASI_PUBLICKEY_ENCODING_RAW, + "d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_RAW, + "9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60"_u8v}}, + {{__WASI_KEYPAIR_ENCODING_RAW, + "9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a"_u8v}}); + + auto RsaCheck = + [EncodingCheck]( + std::string_view Bit, + std::map<__wasi_publickey_encoding_e_t, std::vector> + SupportPk, + std::map<__wasi_secretkey_encoding_e_t, std::vector> + SupportSk, + std::map<__wasi_keypair_encoding_e_t, std::vector> + SupportKp) { + if (Bit == "2048"sv) { + EncodingCheck("RSA_PSS_2048_SHA256"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PSS_2048_SHA384"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PSS_2048_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_2048_SHA256"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_2048_SHA384"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_2048_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + } + if (Bit == "3072"sv) { + EncodingCheck("RSA_PSS_3072_SHA384"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PSS_3072_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_3072_SHA384"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_3072_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + } + if (Bit == "4096"sv) { + EncodingCheck("RSA_PSS_4096_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_4096_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + } + }; + + auto ManagedNegativeCheck = [this]( + __wasi_secrets_manager_t SmHandle, + __wasi_algorithm_type_e_t AlgType, + std::string_view Alg, + std::optional<__wasi_options_t> OptOptions, + __wasi_crypto_errno_e_t ExpectedError) { + SCOPED_TRACE(Alg); + WASI_CRYPTO_EXPECT_FAILURE( + keypairGenerateManaged(SmHandle, AlgType, Alg, OptOptions), + ExpectedError); + }; + + RsaCheck( + "2048"sv, + {{__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtoGqtGXL6bqbDK0IJ2pJ\n" + "fM7hfMpyYVXWxtEdv5oNErvGWpHKiH1NIeQxn0ks1QFmkHoK8Quvzyn5/anHLxiZ\n" + "ecWbzRA01MTfs1y7HMBlrToZhRTZPFe7VRGJ+Liv1jpIiRHPzqabZrssgs3Kj5fG\n" + "0EYXITaQRfn0kfZcYtJLYi0OvU18DBi64MrLwABr3wqn2UZgMgiw3MhKvyXRybca\n" + "x0ASO1RTxAJIm21XuFWTztHcJBvl66ygDAzzRdOJyPWvG+TuhNXvZ7dtA0N4iU8p\n" + "SwJljzLEzWzwKOgAizx3Q3EdS+9P+pTdKtei9UGWVunoj46kCw+0QasQE958NPa3\n" + "uQIDAQAB\n" + "-----END PUBLIC KEY-----\n"_u8}, + {__WASI_PUBLICKEY_ENCODING_PKCS8, + "30820122300d06092a864886f70d01010105000382010f003082010a0282" + "010100b681aab465cbe9ba9b0cad08276a497ccee17cca726155d6c6d11d" + "bf9a0d12bbc65a91ca887d4d21e4319f492cd50166907a0af10bafcf29f9" + "fda9c72f189979c59bcd1034d4c4dfb35cbb1cc065ad3a198514d93c57bb" + "551189f8b8afd63a488911cfcea69b66bb2c82cdca8f97c6d04617213690" + "45f9f491f65c62d24b622d0ebd4d7c0c18bae0cacbc0006bdf0aa7d94660" + "3208b0dcc84abf25d1c9b71ac740123b5453c402489b6d57b85593ced1dc" + "241be5ebaca00c0cf345d389c8f5af1be4ee84d5ef67b76d034378894f29" + "4b02658f32c4cd6cf028e8008b3c7743711d4bef4ffa94dd2ad7a2f54196" + "56e9e88f8ea40b0fb441ab1013de7c34f6b7b90203010001"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC2gaq0ZcvpupsM\n" + "rQgnakl8zuF8ynJhVdbG0R2/mg0Su8ZakcqIfU0h5DGfSSzVAWaQegrxC6/PKfn9\n" + "qccvGJl5xZvNEDTUxN+zXLscwGWtOhmFFNk8V7tVEYn4uK/WOkiJEc/OpptmuyyC\n" + "zcqPl8bQRhchNpBF+fSR9lxi0ktiLQ69TXwMGLrgysvAAGvfCqfZRmAyCLDcyEq/\n" + "JdHJtxrHQBI7VFPEAkibbVe4VZPO0dwkG+XrrKAMDPNF04nI9a8b5O6E1e9nt20D\n" + "Q3iJTylLAmWPMsTNbPAo6ACLPHdDcR1L70/6lN0q16L1QZZW6eiPjqQLD7RBqxAT\n" + "3nw09re5AgMBAAECggEAJds7p3O+GltEshpqKJLZb3QSPapYk2wUwuS5gPbZY1tj\n" + "x4GaOzmSeEc3K80n6X8C4VEPV/SOoTAZ1M4UrOYzX5jnul90NfYoWLIRdeNKs+Xr\n" + "STmL3gJsrzaWIetdPdiVFymEq17PuT11/CPnsmVPLgB7573Dq2AvpN8vRqhMTq6j\n" + "6JoaErGi6nqhfOZ0J7GfwKEJtMGPWbOwyjsMpLxrW3PAy0YH99UZD8Ocxw39hERH\n" + "8i9SGSjk1ty5/NnpC6K8wJmb7BOUXFl1g00FkH9nI/rydjV3Xc6L+vLeLrTvDE59\n" + "uPtRKoFSCbHmxoARkzm8RapX+R/tsJOUxwwPRMjaAQKBgQDsja822AFkjKoDtK4k\n" + "R6R9FKcvW/mUVeKz4ldMQynA+5adU2cyNI7UVRWYcbJI7eSefIaaFe/aHhA2NSny\n" + "83+IhcPfa+z10TBz93CcLs7szYgBddrjYsQ8Q7gc73dGBRqm4oQW+HpuFqPlU5z1\n" + "NZ7eBtlJJBCdxpilR/bC+GyQYQKBgQDFgo2vg1dMKEIm7YfRSXojpRUcAxS4qCJe\n" + "hOcZePkkd1mHmwBkOkX6MMlUozNiCXTmwXjRMWRgZJktbH0I9pwSHCEzJbmfnF5W\n" + "GUwVdfCEvZ4gIcgDZyhsm4A9+L2qC5Nm3dr30NR7c34K1ZpiK3AJTWoI19uEmY46\n" + "1ryOGH5GWQKBgQCxKdAPGCm636q5ScmefFWSJDSuQIkkckpuhNby095inUqJG5zP\n" + "OhO6rNqWqJhpDFpL5GF+520Sg6+KmbiIL5vVaLFxFEiNNhW+1JPvNRNewPPafCTq\n" + "Zd8ob2NlsGc49rumP0HEXmZ7KtOm/j8wWu9Xw/NaVvtm3wUVzFbgYOQWIQKBgQC+\n" + "T2mOcJOxQilbsQxpUM9rgSmx8BYLR5a2VIEJPlNyG74ct/HMoYnD5TZZY1ejY1FM\n" + "96ceiuUZLFWcOyjPdjA0Ev66deNCND2B4KY7F4VFoh+2/lXnUYLWA4+yJvc53iWN\n" + "vL+8gW/78/DDJ8a2SPyPOhStqLBQOFWfxEGy+U7TIQKBgA+cdl5pm1YcVNhcfGbx\n" + "zjN75MLGkq9QfL6zxqWIv4xUsjyYkwGgqwazMeNmi5KvhgpfUM8A8tJQixXmq/oe\n" + "m8MDtOovmQ3I1S6jYK7V8wyxyqgj/2oetPhRIjvht8IaHuFKQkjv6424vdoukuoM\n" + "3o5BHs2kyvh9kuSthBY9XZnN\n" + "-----END PRIVATE KEY-----\n"_u8}, + {__WASI_SECRETKEY_ENCODING_PKCS8, + "308204be020100300d06092a864886f70d0101010500048204a8308204a4" + "0201000282010100b681aab465cbe9ba9b0cad08276a497ccee17cca7261" + "55d6c6d11dbf9a0d12bbc65a91ca887d4d21e4319f492cd50166907a0af1" + "0bafcf29f9fda9c72f189979c59bcd1034d4c4dfb35cbb1cc065ad3a1985" + "14d93c57bb551189f8b8afd63a488911cfcea69b66bb2c82cdca8f97c6d0" + "461721369045f9f491f65c62d24b622d0ebd4d7c0c18bae0cacbc0006bdf" + "0aa7d946603208b0dcc84abf25d1c9b71ac740123b5453c402489b6d57b8" + "5593ced1dc241be5ebaca00c0cf345d389c8f5af1be4ee84d5ef67b76d03" + "4378894f294b02658f32c4cd6cf028e8008b3c7743711d4bef4ffa94dd2a" + "d7a2f5419656e9e88f8ea40b0fb441ab1013de7c34f6b7b9020301000102" + "82010025db3ba773be1a5b44b21a6a2892d96f74123daa58936c14c2e4b9" + "80f6d9635b63c7819a3b39927847372bcd27e97f02e1510f57f48ea13019" + "d4ce14ace6335f98e7ba5f7435f62858b21175e34ab3e5eb49398bde026c" + "af369621eb5d3dd895172984ab5ecfb93d75fc23e7b2654f2e007be7bdc3" + "ab602fa4df2f46a84c4eaea3e89a1a12b1a2ea7aa17ce67427b19fc0a109" + "b4c18f59b3b0ca3b0ca4bc6b5b73c0cb4607f7d5190fc39cc70dfd844447" + "f22f521928e4d6dcb9fcd9e90ba2bcc0999bec13945c5975834d05907f67" + "23faf27635775dce8bfaf2de2eb4ef0c4e7db8fb512a815209b1e6c68011" + "9339bc45aa57f91fedb09394c70c0f44c8da0102818100ec8daf36d80164" + "8caa03b4ae2447a47d14a72f5bf99455e2b3e2574c4329c0fb969d536732" + "348ed455159871b248ede49e7c869a15efda1e10363529f2f37f8885c3df" + "6becf5d13073f7709c2eceeccd880175dae362c43c43b81cef7746051aa6" + "e28416f87a6e16a3e5539cf5359ede06d94924109dc698a547f6c2f86c90" + "6102818100c5828daf83574c284226ed87d1497a23a5151c0314b8a8225e" + "84e71978f9247759879b00643a45fa30c954a333620974e6c178d1316460" + "64992d6c7d08f69c121c213325b99f9c5e56194c1575f084bd9e2021c803" + "67286c9b803df8bdaa0b9366dddaf7d0d47b737e0ad59a622b70094d6a08" + "d7db84998e3ad6bc8e187e465902818100b129d00f1829badfaab949c99e" + "7c55922434ae408924724a6e84d6f2d3de629d4a891b9ccf3a13baacda96" + "a898690c5a4be4617ee76d1283af8a99b8882f9bd568b17114488d3615be" + "d493ef35135ec0f3da7c24ea65df286f6365b06738f6bba63f41c45e667b" + "2ad3a6fe3f305aef57c3f35a56fb66df0515cc56e060e4162102818100be" + "4f698e7093b142295bb10c6950cf6b8129b1f0160b4796b65481093e5372" + "1bbe1cb7f1cca189c3e536596357a363514cf7a71e8ae5192c559c3b28cf" + "76303412feba75e342343d81e0a63b178545a21fb6fe55e75182d6038fb2" + "26f739de258dbcbfbc816ffbf3f0c327c6b648fc8f3a14ada8b05038559f" + "c441b2f94ed3210281800f9c765e699b561c54d85c7c66f1ce337be4c2c6" + "92af507cbeb3c6a588bf8c54b23c989301a0ab06b331e3668b92af860a5f" + "50cf00f2d2508b15e6abfa1e9bc303b4ea2f990dc8d52ea360aed5f30cb1" + "caa823ff6a1eb4f851223be1b7c21a1ee14a4248efeb8db8bdda2e92ea0c" + "de8e411ecda4caf87d92e4ad84163d5d99cd"_u8v}}, + {}); + RsaCheck( + "3072"sv, + {{__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MIIBojANBgkqhkiG9w0BAQEFAAOCAY8AMIIBigKCAYEAnvafMCYCGPGOhpvpA3FG\n" + "8EtN6P/wIQv+fC3MuqrOo3JBYyYVGIxMrpxX/iGEDOEDWkq9Lp7ANFKMu2P2cR+O\n" + "udZ+7WvIB8nINcvPZCGdqndOL9PoJV2zJveQ6FVMZpbh7Nc+dj5Mc2WFiE1LwkEe\n" + "aZD4c5r5UsT2CUvfdDTUoRqNnYVAjnQWfurrfo/o9gjietrvvGKT/dOtdDh/WvEl\n" + "3+HRwVrEeffsO/1tEPbAPHHvtnd1gTNJbNtzVAAUsHT7a4OxPK7JS3hmtv6JRp/Z\n" + "j9VUkDie86a1nD7dFW+S7a3N0W/0EMBiWJduiYye1Qitf2Uf0Dpo78J8lnJce7zR\n" + "1UwQc6EIuNlZh4EODIzz+Pm39locuuFgVVq38dcNStCeSX03GuL75SN3sBT6vXms\n" + "PeEhjhQxxlkGICbfuhCk1TPHj2Q8UROUAvZzbgspbhrtYYaVGrOc+eWBsqlgRzzV\n" + "bF//cJ9yRZy0PzsQSVTonAGpNXLHKTpB438hz5auwfNfAgMBAAE=\n" + "-----END PUBLIC KEY-----\n"_u8}, + {__WASI_PUBLICKEY_ENCODING_PKCS8, + "308201a2300d06092a864886f70d01010105000382018f003082018a0282" + "0181009ef69f30260218f18e869be9037146f04b4de8fff0210bfe7c2dcc" + "baaacea37241632615188c4cae9c57fe21840ce1035a4abd2e9ec034528c" + "bb63f6711f8eb9d67eed6bc807c9c835cbcf64219daa774e2fd3e8255db3" + "26f790e8554c6696e1ecd73e763e4c736585884d4bc2411e6990f8739af9" + "52c4f6094bdf7434d4a11a8d9d85408e74167eeaeb7e8fe8f608e27adaef" + "bc6293fdd3ad74387f5af125dfe1d1c15ac479f7ec3bfd6d10f6c03c71ef" + "b677758133496cdb73540014b074fb6b83b13caec94b7866b6fe89469fd9" + "8fd55490389ef3a6b59c3edd156f92edadcdd16ff410c06258976e898c9e" + "d508ad7f651fd03a68efc27c96725c7bbcd1d54c1073a108b8d95987810e" + "0c8cf3f8f9b7f65a1cbae160555ab7f1d70d4ad09e497d371ae2fbe52377" + "b014fabd79ac3de1218e1431c659062026dfba10a4d533c78f643c511394" + "02f6736e0b296e1aed6186951ab39cf9e581b2a960473cd56c5fff709f72" + "459cb43f3b104954e89c01a93572c7293a41e37f21cf96aec1f35f020301" + "0001"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIIG/AIBADANBgkqhkiG9w0BAQEFAASCBuYwggbiAgEAAoIBgQCe9p8wJgIY8Y6G\n" + "m+kDcUbwS03o//AhC/58Lcy6qs6jckFjJhUYjEyunFf+IYQM4QNaSr0unsA0Uoy7\n" + "Y/ZxH4651n7ta8gHycg1y89kIZ2qd04v0+glXbMm95DoVUxmluHs1z52PkxzZYWI\n" + "TUvCQR5pkPhzmvlSxPYJS990NNShGo2dhUCOdBZ+6ut+j+j2COJ62u+8YpP90610\n" + "OH9a8SXf4dHBWsR59+w7/W0Q9sA8ce+2d3WBM0ls23NUABSwdPtrg7E8rslLeGa2\n" + "/olGn9mP1VSQOJ7zprWcPt0Vb5Ltrc3Rb/QQwGJYl26JjJ7VCK1/ZR/QOmjvwnyW\n" + "clx7vNHVTBBzoQi42VmHgQ4MjPP4+bf2Why64WBVWrfx1w1K0J5JfTca4vvlI3ew\n" + "FPq9eaw94SGOFDHGWQYgJt+6EKTVM8ePZDxRE5QC9nNuCyluGu1hhpUas5z55YGy\n" + "qWBHPNVsX/9wn3JFnLQ/OxBJVOicAak1cscpOkHjfyHPlq7B818CAwEAAQKCAYBK\n" + "w+QLWVUTNkm6tgnaPKUIz+JM/FOMt39yGHh6M2wNI+ftIjQ534MRfSdFt63MAOj6\n" + "xrxD+RadhVX7rQB0JEuUzHXWZSMnxpgL9VgN2GG3k3WKuTgumutwIHBfVf8hIUYR\n" + "hwsxwgtjGxS7Dt/a9ZXAQRcaCIHLlCfEJ5NprI91Vm/U7p92YNNTzloEpNsFHRio\n" + "f+DR0euZLr4eM5RyyYjuy99D+dT/KMRLUt7BY8z2oQAF6hmyMtUOBgkwMPmKJPp9\n" + "xBLHH/25jcY9t3+PM9vT3xkrJ3GrcF9kHHd4R03lLdxuvf0B/AvHlZUphwYcf6ET\n" + "eKS1H9t1CII2Z0ZKHlv2C4Zy3ZH06yIFPNjJI9wCJcq9QnH1SN3Kf1Jak7iXT0xC\n" + "rOrUE4N2mcp8nF3HhgeZZ0vPCW/NBvwHmZgdDIrZI7m5aGq+MZr2JOqrcHgVDMC6\n" + "KCvMpkJRN8Ine6ICpBKd4tl79mrRBbXWK/hCEIwSG9o7C8CPV6NrnZMfgVMUqgEC\n" + "gcEAzWoCtDSsn08CUivWoLdwcvXkCXWihRYM45EwSY3Zkz1YLsjp7NNGoX0mOzWW\n" + "aDa6Ofm2KSQv2pBw3oVcR7ns/IUg/2GSP+v1aZpWGx7wPUu3dbuMfpnt8PKUazbL\n" + "JoBF83bMXZN5LKijEW4HWqO5WsjUwwLUJm0ouvSabnMof6sCeviiDe38Be2lXGfJ\n" + "jCsj9hvmGUqPdFCQJ9c6IjXaP20qvOYNDdVYJ5AJLZywfH7RDGHd4+GH5GToHXTr\n" + "9h7hAoHBAMYcM7jj2NKaNkEYnf7YTnjNymdDYLwm7UCFDKHmluYzleNHuyQu53Y+\n" + "hSyNhKRRVtdxuSKybDPKzTSw5gN4YP1kCGITcDTIQkt9W70Z/U4JL5WSaTyFelZY\n" + "Wg8MACPZhIDexZ+f+bNA82VBinE40kTXAsp+I/dNZHusZJnHR0Q0bnarfCz8UuQe\n" + "qFO4vVH71MpJvthVD/wHuhHZT/NpRXDwgrkKVOC8VwK5U2ZTtVf8uppdsBr0gmyX\n" + "FHI+kz2aPwKBwQCF1E2SrsbQvB8c/ibFav4+R+mcKCIMZ0NaeFtncJ2SimMLiCav\n" + "/y6DRBBGfzFREGbgIssFnuf2lCiVMXnf2UiHdQz8lcs9DjRD6yOyY8PNi6kpcVml\n" + "mhAl7UW5XGea2/O3HW0kglJuQCiN0IvGB+lZNoM30n350yC4PWjoEOsP0pC5IYgj\n" + "XyvViPE1dQEg63Jwg9i0HZm9BEgHTPg5FbDtpeg0TgWvP5JBpFv2daGeWtlEIfb4\n" + "4xUwPnXjyyt4nMECgb9cFr/0MfWX8BdIKylGTUYs4Xw0hB1zWKTwWOiGWanLWC9U\n" + "dwOGzkbJsEY3b5E40JaNj09/0XB6osrAs3o4IrzzDIzZCjAeWPh4Hs2GGY6lt59m\n" + "56gDeghkGq3CUNG/2Fy/is5SZQqtSIPbjZvNBZy4Yzno5rnROyh6VKhu0zNNgRHY\n" + "F96hCql9YMLeKAHZGjbP0XflF6VWgkD8CwgfHdApr6MUYLkTvnizy3H5HvAs9k3H\n" + "c8Vowj/eOlxGvs+y0wKBwHOIGyQ49ethiJMUyt9orVfV04luQgvOOi2VRu8ZQesp\n" + "hA3tuBVilKq7G9EnH2EHy9/xpOql4SzqvboCivh4oKf0W8UyAKQ3Yv7/h4T/obRY\n" + "QzPUB3vc657B+5DzQEhjnP88BF6poznANEmuK9YJ4OXeA3RQaQiIEIFbY+GgaO7u\n" + "znZv+HE2UJcfnbGW3m/naIKfIhcIaq2bfWOG5hYd1L6Dazkfttc2GtRYaBpyo0bz\n" + "gpvbM5ETOLFocm1MM1hcOA==\n" + "-----END PRIVATE KEY-----\n"_u8}, + {__WASI_SECRETKEY_ENCODING_PKCS8, + "308206fc020100300d06092a864886f70d0101010500048206e6308206e2" + "02010002820181009ef69f30260218f18e869be9037146f04b4de8fff021" + "0bfe7c2dccbaaacea37241632615188c4cae9c57fe21840ce1035a4abd2e" + "9ec034528cbb63f6711f8eb9d67eed6bc807c9c835cbcf64219daa774e2f" + "d3e8255db326f790e8554c6696e1ecd73e763e4c736585884d4bc2411e69" + "90f8739af952c4f6094bdf7434d4a11a8d9d85408e74167eeaeb7e8fe8f6" + "08e27adaefbc6293fdd3ad74387f5af125dfe1d1c15ac479f7ec3bfd6d10" + "f6c03c71efb677758133496cdb73540014b074fb6b83b13caec94b7866b6" + "fe89469fd98fd55490389ef3a6b59c3edd156f92edadcdd16ff410c06258" + "976e898c9ed508ad7f651fd03a68efc27c96725c7bbcd1d54c1073a108b8" + "d95987810e0c8cf3f8f9b7f65a1cbae160555ab7f1d70d4ad09e497d371a" + "e2fbe52377b014fabd79ac3de1218e1431c659062026dfba10a4d533c78f" + "643c51139402f6736e0b296e1aed6186951ab39cf9e581b2a960473cd56c" + "5fff709f72459cb43f3b104954e89c01a93572c7293a41e37f21cf96aec1" + "f35f0203010001028201804ac3e40b5955133649bab609da3ca508cfe24c" + "fc538cb77f7218787a336c0d23e7ed223439df83117d2745b7adcc00e8fa" + "c6bc43f9169d8555fbad0074244b94cc75d6652327c6980bf5580dd861b7" + "93758ab9382e9aeb7020705f55ff21214611870b31c20b631b14bb0edfda" + "f595c041171a0881cb9427c4279369ac8f75566fd4ee9f7660d353ce5a04" + "a4db051d18a87fe0d1d1eb992ebe1e339472c988eecbdf43f9d4ff28c44b" + "52dec163ccf6a10005ea19b232d50e06093030f98a24fa7dc412c71ffdb9" + "8dc63db77f8f33dbd3df192b2771ab705f641c7778474de52ddc6ebdfd01" + "fc0bc795952987061c7fa11378a4b51fdb7508823667464a1e5bf60b8672" + "dd91f4eb22053cd8c923dc0225cabd4271f548ddca7f525a93b8974f4c42" + "acead413837699ca7c9c5dc7860799674bcf096fcd06fc0799981d0c8ad9" + "23b9b9686abe319af624eaab7078150cc0ba282bcca6425137c2277ba202" + "a4129de2d97bf66ad105b5d62bf842108c121bda3b0bc08f57a36b9d931f" + "815314aa010281c100cd6a02b434ac9f4f02522bd6a0b77072f5e40975a2" + "85160ce39130498dd9933d582ec8e9ecd346a17d263b35966836ba39f9b6" + "29242fda9070de855c47b9ecfc8520ff61923febf5699a561b1ef03d4bb7" + "75bb8c7e99edf0f2946b36cb268045f376cc5d93792ca8a3116e075aa3b9" + "5ac8d4c302d4266d28baf49a6e73287fab027af8a20dedfc05eda55c67c9" + "8c2b23f61be6194a8f74509027d73a2235da3f6d2abce60d0dd558279009" + "2d9cb07c7ed10c61dde3e187e464e81d74ebf61ee10281c100c61c33b8e3" + "d8d29a3641189dfed84e78cdca674360bc26ed40850ca1e696e63395e347" + "bb242ee7763e852c8d84a45156d771b922b26c33cacd34b0e6037860fd64" + "0862137034c8424b7d5bbd19fd4e092f9592693c857a56585a0f0c0023d9" + "8480dec59f9ff9b340f365418a7138d244d702ca7e23f74d647bac6499c7" + "4744346e76ab7c2cfc52e41ea853b8bd51fbd4ca49bed8550ffc07ba11d9" + "4ff3694570f082b90a54e0bc5702b9536653b557fcba9a5db01af4826c97" + "14723e933d9a3f0281c10085d44d92aec6d0bc1f1cfe26c56afe3e47e99c" + "28220c67435a785b67709d928a630b8826afff2e834410467f31511066e0" + "22cb059ee7f69428953179dfd94887750cfc95cb3d0e3443eb23b263c3cd" + "8ba9297159a59a1025ed45b95c679adbf3b71d6d2482526e40288dd08bc6" + "07e959368337d27df9d320b83d68e810eb0fd290b92188235f2bd588f135" + "750120eb727083d8b41d99bd0448074cf83915b0eda5e8344e05af3f9241" + "a45bf675a19e5ad94421f6f8e315303e75e3cb2b789cc10281bf5c16bff4" + "31f597f017482b29464d462ce17c34841d7358a4f058e88659a9cb582f54" + "770386ce46c9b046376f9138d0968d8f4f7fd1707aa2cac0b37a3822bcf3" + "0c8cd90a301e58f8781ecd86198ea5b79f66e7a8037a08641aadc250d1bf" + "d85cbf8ace52650aad4883db8d9bcd059cb86339e8e6b9d13b287a54a86e" + "d3334d8111d817dea10aa97d60c2de2801d91a36cfd177e517a5568240fc" + "0b081f1dd029afa31460b913be78b3cb71f91ef02cf64dc773c568c23fde" + "3a5c46becfb2d30281c073881b2438f5eb61889314cadf68ad57d5d3896e" + "420bce3a2d9546ef1941eb29840dedb8156294aabb1bd1271f6107cbdff1" + "a4eaa5e12ceabdba028af878a0a7f45bc53200a43762feff8784ffa1b458" + "4333d4077bdceb9ec1fb90f34048639cff3c045ea9a339c03449ae2bd609" + "e0e5de03745069088810815b63e1a068eeeece766ff8713650971f9db196" + "de6fe768829f2217086aad9b7d6386e6161dd4be836b391fb6d7361ad458" + "681a72a346f3829bdb33911338b168726d4c33585c38"_u8v}}, + {}); + RsaCheck( + "4096"sv, + {{__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEArcIlHK14T1jiaKxiiS3E\n" + "jOcQvxErqnc/aQLPXNIx/d21D1/Ii1vODoSTt5KAvWE87ah0rE5oAs0wMVRWSAV+\n" + "evUAG+WVjBiN8crGilTVWyylCNVBM2Kwrt95rDYLz68OL7ZOgJ393KfZ2HbMq51Q\n" + "d6AJsvTkEGVdbbO3q9Opx1htoVm4Crac3TgKpNQceQ6i7b3gFBMaIjGNhD8fuWgK\n" + "5FAYQgstzURXB5B/vzgOSgkG70q1Ksr8iq5q5UpQRoVaUwq7TXvFR0hARoE0lS7T\n" + "WeAA5GksGwLYPDfg7WvZn1LmCLU0OAAdPJAegFOZa87rKZTo8PXWMgbQM15/rz1G\n" + "DFVbJRToGr1pgO0lNZ19uBFfmStcQowHNR1is9GCv08yb6ZQBszA3JhGet2Aaqy9\n" + "HgN6nS8bzJnYljz3q5MTwzLvtBYl7Bw4MBDypbPnqHeGyUY1LJ8BkRbX0b39FPbl\n" + "9ywWXDnjJ8sly1D3FQmtOiN54grrYW8Ee9kTPQchdgogOX1pJRTRuEnGJI5aRiDE\n" + "gVkrx12jSymdeUpuy7UdvPkWKQYzwgABgNc7RZs6+qWkQZBsU96jSxMayKTsP5U6\n" + "fT2rkSAsmXH6HPjtD2gvpbvl32MUBGlbNOUYh4gePv01GiztuONCjASfhRvjJagK\n" + "4zn18PfVWWpgqas4rrFDyLECAwEAAQ==\n" + "-----END PUBLIC KEY-----\n"_u8}, + {__WASI_PUBLICKEY_ENCODING_PKCS8, + "30820222300d06092a864886f70d01010105000382020f003082020a0282" + "020100adc2251cad784f58e268ac62892dc48ce710bf112baa773f6902cf" + "5cd231fdddb50f5fc88b5bce0e8493b79280bd613ceda874ac4e6802cd30" + "31545648057e7af5001be5958c188df1cac68a54d55b2ca508d5413362b0" + "aedf79ac360bcfaf0e2fb64e809dfddca7d9d876ccab9d5077a009b2f4e4" + "10655d6db3b7abd3a9c7586da159b80ab69cdd380aa4d41c790ea2edbde0" + "14131a22318d843f1fb9680ae45018420b2dcd445707907fbf380e4a0906" + "ef4ab52acafc8aae6ae54a5046855a530abb4d7bc5474840468134952ed3" + "59e000e4692c1b02d83c37e0ed6bd99f52e608b53438001d3c901e805399" + "6bceeb2994e8f0f5d63206d0335e7faf3d460c555b2514e81abd6980ed25" + "359d7db8115f992b5c428c07351d62b3d182bf4f326fa65006ccc0dc9846" + "7add806aacbd1e037a9d2f1bcc99d8963cf7ab9313c332efb41625ec1c38" + "3010f2a5b3e7a87786c946352c9f019116d7d1bdfd14f6e5f72c165c39e3" + "27cb25cb50f71509ad3a2379e20aeb616f047bd9133d0721760a20397d69" + "2514d1b849c6248e5a4620c481592bc75da34b299d794a6ecbb51dbcf916" + "290633c2000180d73b459b3afaa5a441906c53dea34b131ac8a4ec3f953a" + "7d3dab91202c9971fa1cf8ed0f682fa5bbe5df631404695b34e51887881e" + "3efd351a2cedb8e3428c049f851be325a80ae339f5f0f7d5596a60a9ab38" + "aeb143c8b10203010001"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCtwiUcrXhPWOJo\n" + "rGKJLcSM5xC/ESuqdz9pAs9c0jH93bUPX8iLW84OhJO3koC9YTztqHSsTmgCzTAx\n" + "VFZIBX569QAb5ZWMGI3xysaKVNVbLKUI1UEzYrCu33msNgvPrw4vtk6Anf3cp9nY\n" + "dsyrnVB3oAmy9OQQZV1ts7er06nHWG2hWbgKtpzdOAqk1Bx5DqLtveAUExoiMY2E\n" + "Px+5aArkUBhCCy3NRFcHkH+/OA5KCQbvSrUqyvyKrmrlSlBGhVpTCrtNe8VHSEBG\n" + "gTSVLtNZ4ADkaSwbAtg8N+Dta9mfUuYItTQ4AB08kB6AU5lrzusplOjw9dYyBtAz\n" + "Xn+vPUYMVVslFOgavWmA7SU1nX24EV+ZK1xCjAc1HWKz0YK/TzJvplAGzMDcmEZ6\n" + "3YBqrL0eA3qdLxvMmdiWPPerkxPDMu+0FiXsHDgwEPKls+eod4bJRjUsnwGRFtfR\n" + "vf0U9uX3LBZcOeMnyyXLUPcVCa06I3niCuthbwR72RM9ByF2CiA5fWklFNG4ScYk\n" + "jlpGIMSBWSvHXaNLKZ15Sm7LtR28+RYpBjPCAAGA1ztFmzr6paRBkGxT3qNLExrI\n" + "pOw/lTp9PauRICyZcfoc+O0PaC+lu+XfYxQEaVs05RiHiB4+/TUaLO2440KMBJ+F\n" + "G+MlqArjOfXw99VZamCpqziusUPIsQIDAQABAoICAQCfy0GyA932qrlcpdvgaBSv\n" + "t/fwnuvXUt8fxZPJuwx6eR//yYh2kLEJLOdkFPkMMJaFwTu7EkgY+3ZshzDp/xN4\n" + "JEQ7Y4GKWzJ+wIqhwK6NsJr9apERnpr5107gDrwB/O1A95luMt25xStUJLzIvl24\n" + "BZel2gy6/11Se8pX3MnwJ+R6VDYqtBHCZ71yJBcjRVCU7t9Z1s9bztJkYmDcc1BA\n" + "81+7rOgsM8MNk9fHlNefQnn8Kmo9tntVVl28DAGTOSP95oqmEUM18L4bmMswvuVj\n" + "a9umMwp6tL0DdCgIb/ysxuIB9BLXxVMd1TQXs8oOGTavAODQaGTZkOZ7t1YZZHI7\n" + "c8rJwOycVKDE9Y0prtI5oyyUF4kbOkICO7aHiOaXRX1JVsUVHV2ck5/UQfAyq2XZ\n" + "/Q99j8rHJYHmCOBoy7ECQ87SStGU/ypyl3PC31af8lTEQpN+aj/B2yq2NaaEfouw\n" + "fW9kNNMkfizJQ6aB4uYya63vZBdAYlZ8T3KbgiEFTPO6Am03BipReChDtfQPlNlx\n" + "eKcYxKXwEp+ddDu+fNmXowpJeiuFbc3/tSS8WvfMQDmVu0ikOG5MfDy1tL22F42w\n" + "lhNnVtt1lwN2op3/WfyfyroNqMA+pT2BqnjzAH/cAhKnI6DU+8bTG8jmPpMnftXH\n" + "k91OUUjucDe+pLLJ0gdE+QKCAQEA56Wq/2x8drGlAYhoRmS8lJldtU2zGUCXnBHk\n" + "roZVOkRvG0bDZTokBkn7ZHhH8/uCIjIUaNJRbpFDzlbtShxYphzunIrJq8OhSoAE\n" + "vKOqVII5rDQ9WlZvIIZ46SQcg6uFbA6l95hi2WPwjl5Enzc3sw1i6fEv4eziPWFa\n" + "JEHlZaonuIEjo8uHKZBhgG3nW5w5oYcPzTOM3ogca2DM1xa1hBzJi0P28Zn+NJBt\n" + "JPnfZiyHlzsXdAtlJtvTvpVknIpwnjfL8vM9RCk1+EfWXGqgrn/pI1IlfQn1ZPRd\n" + "/xJ/+qwZN2StReHjU3Z+KjNR+roALpBV/AhkPd36rmhoXqHPTwKCAQEAwAaDfT2U\n" + "6MzV0cudB1Q1UxrdFqsyxycIWWkxenn/LdOULecN0ujoW8TSTz4PWsXRDAhtdRS0\n" + "nsg4XalKb9UYnqzohNQcqv0rJR/M1HaZ4coUACZcICfiAdFWP1SiFZgu9PJW/EUJ\n" + "u2z08t35iOGwGLxfHFPoiZaomAVlyL+kkEXA8AWwms5QCZjf+IZuT/RbUK4qgzcR\n" + "xbYsv0cjmymCObrFNQjjsedQDygOmj0rA8X1Q/tNfKsx86YJV2CsErWJW4RUZP/h\n" + "Ws9kR40l2NvJnWZEHsnIRNPFgJ90NavXvUwY4VEt0D9An9jiWXLXF5NpChFiWB6/\n" + "8EvgwiydQhrn/wKCAQEAyk5xbOm+OZsj1JbhGrlXyR+4K2NUizVSM0edRJ6lSGID\n" + "9vpyI7IHTEbIexJhJL//AwZhtLoZzEqpwUdBrXvcIBccfTLotk4ASyRK/sShOXUS\n" + "EUb+Xismmm1Wo6aaEJR3zcttPzOjAOC7clr562M6DfIe9NljTBip7ZlcNFYolgVo\n" + "80Y1bhOOU8p4nMVfTS6/Vkayki/3U1HkIBNGUoLOvDa3/hy5Sn+G9zk7WROw+3bg\n" + "ZD+DWCGrkahi4Qtv9xchC80HHYM5epHTRKbYm5W0BzJG1kYj33QXELgqb14kzzQG\n" + "Qc53VZTWCEpwHUL80dAn4ILF1XsusKlxCWi93gfLGQKCAQBwhCCNxQS4+DUdjgI/\n" + "5h6syGPdwYiqWvuwcEv2qP9V2dDMqMNX3vMvun9EwWd718drFpEUdoJzO3yTnPup\n" + "1aJsb4J7OlJl+pxKT3zUzX3TaHYZtGBs0xHB4Oh5iVzD7H0vN8SyYr2WHfzVRi3N\n" + "//gQNmhAkAYEgMve7+K5I1oI02Z+/caCnvsU9Iff9t0yaksLVlJAuobmY52KouOB\n" + "KmxM6VxefAv3FUO67czIoajPuDHDmL/JmgJV8ucsVM/e0pJeloZg+/IPJNBsgI85\n" + "p2dWnDK0G6YGdlQWztfoDv4FxE4b0FZY3IdAYnQW14yjGtQEezU1zybGZZ+YB05K\n" + "Crv/AoIBADlvlifpD0dkIQ2udfG7wUi0OXmj9sepQ73k0n8ULzX1PV3oMyknGs54\n" + "0KlSKGgg/47YezvI4BupBZ0zVox0Ztgilg2oHdQlmEAGGiDBYNLO/Nd9FotytcAu\n" + "gjyanngI5IaXzKVEAW6QFZTkeIavl/S44NtB39MjM7tRaaJNu60PaZ/UOlicyyNj\n" + "RZPS2JrYmymSo9La27si9yY9L/gXCw0yL2/3sbR5cPEkoFN0o9XGDs5BWsnvMLzH\n" + "UwtbtmC2Pw4kp8kzRk21cH75Yl/wg9Oir95uL0C8w4B7FacCPijNPCPsvmYJi9Cg\n" + "Kah5KuUlkoLBFrMqKJCKqhvP72HZUp8=\n" + "-----END PRIVATE KEY-----\n"_u8}, + {__WASI_SECRETKEY_ENCODING_PKCS8, + "30820943020100300d06092a864886f70d01010105000482092d30820929" + "0201000282020100adc2251cad784f58e268ac62892dc48ce710bf112baa" + "773f6902cf5cd231fdddb50f5fc88b5bce0e8493b79280bd613ceda874ac" + "4e6802cd3031545648057e7af5001be5958c188df1cac68a54d55b2ca508" + "d5413362b0aedf79ac360bcfaf0e2fb64e809dfddca7d9d876ccab9d5077" + "a009b2f4e410655d6db3b7abd3a9c7586da159b80ab69cdd380aa4d41c79" + "0ea2edbde014131a22318d843f1fb9680ae45018420b2dcd445707907fbf" + "380e4a0906ef4ab52acafc8aae6ae54a5046855a530abb4d7bc547484046" + "8134952ed359e000e4692c1b02d83c37e0ed6bd99f52e608b53438001d3c" + "901e8053996bceeb2994e8f0f5d63206d0335e7faf3d460c555b2514e81a" + "bd6980ed25359d7db8115f992b5c428c07351d62b3d182bf4f326fa65006" + "ccc0dc98467add806aacbd1e037a9d2f1bcc99d8963cf7ab9313c332efb4" + "1625ec1c383010f2a5b3e7a87786c946352c9f019116d7d1bdfd14f6e5f7" + "2c165c39e327cb25cb50f71509ad3a2379e20aeb616f047bd9133d072176" + "0a20397d692514d1b849c6248e5a4620c481592bc75da34b299d794a6ecb" + "b51dbcf916290633c2000180d73b459b3afaa5a441906c53dea34b131ac8" + "a4ec3f953a7d3dab91202c9971fa1cf8ed0f682fa5bbe5df631404695b34" + "e51887881e3efd351a2cedb8e3428c049f851be325a80ae339f5f0f7d559" + "6a60a9ab38aeb143c8b1020301000102820201009fcb41b203ddf6aab95c" + "a5dbe06814afb7f7f09eebd752df1fc593c9bb0c7a791fffc9887690b109" + "2ce76414f90c309685c13bbb124818fb766c8730e9ff137824443b63818a" + "5b327ec08aa1c0ae8db09afd6a91119e9af9d74ee00ebc01fced40f7996e" + "32ddb9c52b5424bcc8be5db80597a5da0cbaff5d527bca57dcc9f027e47a" + "54362ab411c267bd72241723455094eedf59d6cf5bced2646260dc735040" + "f35fbbace82c33c30d93d7c794d79f4279fc2a6a3db67b55565dbc0c0193" + "3923fde68aa6114335f0be1b98cb30bee5636bdba6330a7ab4bd03742808" + "6ffcacc6e201f412d7c5531dd53417b3ca0e1936af00e0d06864d990e67b" + "b7561964723b73cac9c0ec9c54a0c4f58d29aed239a32c9417891b3a4202" + "3bb68788e697457d4956c5151d5d9c939fd441f032ab65d9fd0f7d8fcac7" + "2581e608e068cbb10243ced24ad194ff2a729773c2df569ff254c442937e" + "6a3fc1db2ab635a6847e8bb07d6f6434d3247e2cc943a681e2e6326badef" + "64174062567c4f729b8221054cf3ba026d37062a51782843b5f40f94d971" + "78a718c4a5f0129f9d743bbe7cd997a30a497a2b856dcdffb524bc5af7cc" + "403995bb48a4386e4c7c3cb5b4bdb6178db096136756db75970376a29dff" + "59fc9fcaba0da8c03ea53d81aa78f3007fdc0212a723a0d4fbc6d31bc8e6" + "3e93277ed5c793dd4e5148ee7037bea4b2c9d20744f90282010100e7a5aa" + "ff6c7c76b1a50188684664bc94995db54db31940979c11e4ae86553a446f" + "1b46c3653a240649fb647847f3fb8222321468d2516e9143ce56ed4a1c58" + "a61cee9c8ac9abc3a14a8004bca3aa548239ac343d5a566f208678e9241c" + "83ab856c0ea5f79862d963f08e5e449f3737b30d62e9f12fe1ece23d615a" + "2441e565aa27b88123a3cb87299061806de75b9c39a1870fcd338cde881c" + "6b60ccd716b5841cc98b43f6f199fe34906d24f9df662c87973b17740b65" + "26dbd3be95649c8a709e37cbf2f33d442935f847d65c6aa0ae7fe9235225" + "7d09f564f45dff127ffaac193764ad45e1e353767e2a3351faba002e9055" + "fc08643dddfaae68685ea1cf4f0282010100c006837d3d94e8ccd5d1cb9d" + "075435531add16ab32c727085969317a79ff2dd3942de70dd2e8e85bc4d2" + "4f3e0f5ac5d10c086d7514b49ec8385da94a6fd5189eace884d41caafd2b" + "251fccd47699e1ca1400265c2027e201d1563f54a215982ef4f256fc4509" + "bb6cf4f2ddf988e1b018bc5f1c53e88996a8980565c8bfa49045c0f005b0" + "9ace500998dff8866e4ff45b50ae2a833711c5b62cbf47239b298239bac5" + "3508e3b1e7500f280e9a3d2b03c5f543fb4d7cab31f3a6095760ac12b589" + "5b845464ffe15acf64478d25d8dbc99d66441ec9c844d3c5809f7435abd7" + "bd4c18e1512dd03f409fd8e25972d71793690a1162581ebff04be0c22c9d" + "421ae7ff0282010100ca4e716ce9be399b23d496e11ab957c91fb82b6354" + "8b355233479d449ea5486203f6fa7223b2074c46c87b126124bfff030661" + "b4ba19cc4aa9c14741ad7bdc20171c7d32e8b64e004b244afec4a1397512" + "1146fe5e2b269a6d56a3a69a109477cdcb6d3f33a300e0bb725af9eb633a" + "0df21ef4d9634c18a9ed995c345628960568f346356e138e53ca789cc55f" + "4d2ebf5646b2922ff75351e42013465282cebc36b7fe1cb94a7f86f7393b" + "5913b0fb76e0643f835821ab91a862e10b6ff717210bcd071d83397a91d3" + "44a6d89b95b4073246d64623df741710b82a6f5e24cf340641ce775594d6" + "084a701d42fcd1d027e082c5d57b2eb0a9710968bdde07cb190282010070" + "84208dc504b8f8351d8e023fe61eacc863ddc188aa5afbb0704bf6a8ff55" + "d9d0cca8c357def32fba7f44c1677bd7c76b1691147682733b7c939cfba9" + "d5a26c6f827b3a5265fa9c4a4f7cd4cd7dd3687619b4606cd311c1e0e879" + "895cc3ec7d2f37c4b262bd961dfcd5462dcdfff81036684090060480cbde" + "efe2b9235a08d3667efdc6829efb14f487dff6dd326a4b0b565240ba86e6" + "639d8aa2e3812a6c4ce95c5e7c0bf71543baedccc8a1a8cfb831c398bfc9" + "9a0255f2e72c54cfded2925e968660fbf20f24d06c808f39a767569c32b4" + "1ba606765416ced7e80efe05c44e1bd05658dc8740627416d78ca31ad404" + "7b3535cf26c6659f98074e4a0abbff02820100396f9627e90f4764210dae" + "75f1bbc148b43979a3f6c7a943bde4d27f142f35f53d5de83329271ace78" + "d0a952286820ff8ed87b3bc8e01ba9059d33568c7466d822960da81dd425" + "9840061a20c160d2cefcd77d168b72b5c02e823c9a9e7808e48697cca544" + "016e901594e47886af97f4b8e0db41dfd32333bb5169a24dbbad0f699fd4" + "3a589ccb23634593d2d89ad89b2992a3d2dadbbb22f7263d2ff8170b0d32" + "2f6ff7b1b47970f124a05374a3d5c60ece415ac9ef30bcc7530b5bb660b6" + "3f0e24a7c933464db5707ef9625ff083d3a2afde6e2f40bcc3807b15a702" + "3e28cd3c23ecbe66098bd0a029a8792ae5259282c116b32a28908aaa1bcf" + "ef61d9529f"_u8v}}, + {}); + + ManagedNegativeCheck(1, __WASI_ALGORITHM_TYPE_SYMMETRIC, "Ed25519"sv, + std::nullopt, __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + ManagedNegativeCheck(1, __WASI_ALGORITHM_TYPE_SIGNATURES, "FooBar"sv, + std::nullopt, __WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); + ManagedNegativeCheck(1, __WASI_ALGORITHM_TYPE_SIGNATURES, "Ed25519"sv, + static_cast<__wasi_options_t>(InvaildHandle), + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/common.cpp b/test/plugins/wasi_crypto/common.cpp new file mode 100644 index 00000000..e8310c33 --- /dev/null +++ b/test/plugins/wasi_crypto/common.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/func.h" +#include "helper.h" + +namespace { +template +inline T *getHostFunc(const M &Mod, const char *Name) { + if (Mod) { + auto *FuncInst = Mod->findFuncExports(Name); + if (FuncInst && FuncInst->isHostFunction()) { + return dynamic_cast(&FuncInst->getHostFunc()); + } + } + return nullptr; +} +} // namespace + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Options) { + // Symmetric options. + { + // Open options. + WASI_CRYPTO_EXPECT_SUCCESS(SymmetricOptionsHandle, + optionsOpen(__WASI_ALGORITHM_TYPE_SYMMETRIC)); + + // Set options. + WASI_CRYPTO_EXPECT_TRUE( + optionsSet(SymmetricOptionsHandle, "context"sv, "foo"_u8)); + WASI_CRYPTO_EXPECT_TRUE( + optionsSet(SymmetricOptionsHandle, "salt"sv, "foo"_u8)); + WASI_CRYPTO_EXPECT_TRUE( + optionsSet(SymmetricOptionsHandle, "nonce"sv, "foo"_u8)); + WASI_CRYPTO_EXPECT_TRUE( + optionsSetU64(SymmetricOptionsHandle, "memory_limit"sv, 0)); + WASI_CRYPTO_EXPECT_TRUE( + optionsSetU64(SymmetricOptionsHandle, "ops_limit"sv, 0)); + WASI_CRYPTO_EXPECT_TRUE( + optionsSetU64(SymmetricOptionsHandle, "parallelism"sv, 0)); + + // Unsupported options. + WASI_CRYPTO_EXPECT_FAILURE( + optionsSet(SymmetricOptionsHandle, "foo"sv, "foo"_u8), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE( + optionsSetU64(SymmetricOptionsHandle, "foo"sv, 0), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + writeDummyMemoryContent(); + writeString("foo"sv, 0); + uint32_t NameSize = 3; + auto *Func = getHostFunc( + WasiCryptoCommonMod, "options_set_guest_buffer"); + ASSERT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + SymmetricOptionsHandle, 0, NameSize, 0, NameSize}, + Errno)); + EXPECT_EQ(Errno[0].get(), __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + WASI_CRYPTO_EXPECT_TRUE(optionsClose(SymmetricOptionsHandle)); + } + + // Signature options. + { + // Open options. + WASI_CRYPTO_EXPECT_SUCCESS(SigOptionsHandle, + optionsOpen(__WASI_ALGORITHM_TYPE_SIGNATURES)); + + // Unsupported options. + WASI_CRYPTO_EXPECT_FAILURE(optionsSet(SigOptionsHandle, "foo"sv, "foo"_u8), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + WASI_CRYPTO_EXPECT_FAILURE(optionsSetU64(SigOptionsHandle, "foo"sv, 0), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + writeDummyMemoryContent(); + writeString("foo"sv, 0); + uint32_t NameSize = 3; + auto *Func = getHostFunc( + WasiCryptoCommonMod, "options_set_guest_buffer"); + ASSERT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + SigOptionsHandle, 0, NameSize, 0, NameSize}, + Errno)); + EXPECT_EQ(Errno[0].get(), __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + // Close options. + WASI_CRYPTO_EXPECT_TRUE(optionsClose(SigOptionsHandle)); + } + + // Key exchange options. + { + // Open options. + WASI_CRYPTO_EXPECT_SUCCESS(KxOptionsHandle, + optionsOpen(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE)); + // Unsupported options. + WASI_CRYPTO_EXPECT_FAILURE(optionsSet(KxOptionsHandle, "foo"sv, "foo"_u8), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + WASI_CRYPTO_EXPECT_FAILURE(optionsSetU64(KxOptionsHandle, "foo"sv, 0), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + writeDummyMemoryContent(); + writeString("foo"sv, 0); + uint32_t NameSize = 3; + auto *Func = getHostFunc( + WasiCryptoCommonMod, "options_set_guest_buffer"); + ASSERT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + KxOptionsHandle, 0, NameSize, 0, NameSize}, + Errno)); + EXPECT_EQ(Errno[0].get(), __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + // Close options. + WASI_CRYPTO_EXPECT_TRUE(optionsClose(KxOptionsHandle)); + } +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/hash.cpp b/test/plugins/wasi_crypto/hash.cpp new file mode 100644 index 00000000..ac615978 --- /dev/null +++ b/test/plugins/wasi_crypto/hash.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Hash) { + auto HashTest = [this](std::string_view Name, + const std::vector &AbsorbData1, + const std::vector &AbsorbData2, + const std::vector &ExpectedSqueezeData1, + const std::vector &ExpectedSqueezeData2, + const std::vector &TruncatedSquueezeData) { + WASI_CRYPTO_EXPECT_SUCCESS( + StateHandle, symmetricStateOpen(Name, std::nullopt, std::nullopt)); + + SCOPED_TRACE(Name); + { + // "data" + std::vector SqueezeContent(ExpectedSqueezeData1.size()); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData1)); + WASI_CRYPTO_EXPECT_TRUE( + symmetricStateSqueeze(StateHandle, SqueezeContent)); + EXPECT_EQ(SqueezeContent, ExpectedSqueezeData1); + } + + { + // "datamore_data" + std::vector SqueezeContent(ExpectedSqueezeData2.size()); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData2)); + WASI_CRYPTO_EXPECT_TRUE( + symmetricStateSqueeze(StateHandle, SqueezeContent)); + EXPECT_EQ(SqueezeContent, ExpectedSqueezeData2); + } + + { + // Smaller than the hash function output size. Truncate the output. + std::vector SqueezeContent(TruncatedSquueezeData.size()); + WASI_CRYPTO_EXPECT_TRUE( + symmetricStateSqueeze(StateHandle, SqueezeContent)); + EXPECT_EQ(SqueezeContent, TruncatedSquueezeData); + } + + { + // Requested size exceeds the returned invalid_length. + std::vector SqueezeContent(ExpectedSqueezeData1.size() + 1); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateSqueeze(StateHandle, SqueezeContent), + __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + } + + { + // Clone checking. + WASI_CRYPTO_EXPECT_SUCCESS(NewStateHandle, + symmetricStateClone(StateHandle)); + EXPECT_NE(StateHandle, NewStateHandle); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(NewStateHandle)); + } + + { + // Some error cases checking. + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyGenerate(Name, std::nullopt), + __WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOpen(Name, InvaildHandle, std::nullopt), + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGet(StateHandle, "foo"sv, {}), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGetU64(StateHandle, "foo"sv), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeTag(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeKey(StateHandle, Name), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateMaxTagLen(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateEncrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateEncryptDetached(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateDecrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateDecryptDetached(StateHandle, {}, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateRatchet(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + // Close. + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(StateHandle)); + }; + + HashTest( + "SHA-256"sv, "data"_u8, "more_data"_u8, + "3a6eb0790f39ac87c94f3856b2dd2c5d110e6811602261a9a923d3bb23adc8b7"_u8v, + "13c40eec22541a155e172010c7fd6ef654e4e138a0c20923f9a91062a27f57b6"_u8v, + "13c40eec22541a155e172010c7fd6ef654e4e138a0c20923f9a91062a27f57"_u8v); + HashTest( + "SHA-512"sv, "data"_u8, "more_data"_u8, + "77c7ce9a5d86bb386d443bb96390faa120633158699c8844c30b13ab0bf92760b7e4416aea397db91b4ac0e5dd56b8ef7e4b066162ab1fdc088319ce6defc876"_u8v, + "78d0b55eeb3a07754f0967a6e960b5b7488b09ec4d2a62d832a45d80f814aef88e5414e2115165012ac592ff050651e956089a5aacd4ea52cf247c3cc2f6add2"_u8v, + "78d0b55eeb3a07754f0967a6e960b5b7488b09ec4d2a62d832a45d80f814aef88e5414e2115165012ac592ff050651e956089a5aacd4ea52cf247c3cc2f6ad"_u8v); + HashTest( + "SHA-512/256"sv, "data"_u8, "more_data"_u8, + "99902eaf90e92264667843cde66675ed94caa361634bad57874642aa364aa968"_u8v, + "d1def71920a44d8b6c83b2eaa99379a16047cc82cec8d80689fbf02fbd062481"_u8v, + "d1def71920a44d8b6c83b2eaa99379a16047cc82cec8d80689fbf02fbd0624"_u8v); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/helper.cpp b/test/plugins/wasi_crypto/helper.cpp new file mode 100644 index 00000000..f3ca00a7 --- /dev/null +++ b/test/plugins/wasi_crypto/helper.cpp @@ -0,0 +1,1446 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "helper.h" +#include "asymmetric_common/func.h" +#include "common/func.h" +#include "kx/func.h" +#include "signatures/func.h" +#include "symmetric/func.h" +#include "utils/error.h" + +#include +#include +#include + +#define ensureOrReturnOnTest(Expr) \ + do { \ + if ((static_cast<__wasi_crypto_errno_e_t>(Expr) != \ + __WASI_CRYPTO_ERRNO_SUCCESS)) { \ + return WasiCryptoUnexpect(static_cast<__wasi_crypto_errno_e_t>(Expr)); \ + } \ + } while (0) + +namespace { +template +inline T *getHostFunc(M &Mod, const char *Name) { + if (Mod) { + auto *FuncInst = Mod->findFuncExports(Name); + if (FuncInst && FuncInst->isHostFunction()) { + return dynamic_cast(&FuncInst->getHostFunc()); + } + } + return nullptr; +} +} // namespace + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +std::vector operator"" _u8(const char *Str, std::size_t Len) { + return std::vector{reinterpret_cast(Str), + reinterpret_cast(Str) + Len}; +} + +std::vector operator"" _u8v(const char *Str, std::size_t Len) { + std::vector Res; + Res.reserve(Len / 2); + for (size_t I = 0; I < Len; I += 2) { + std::string Tran{Str + I, 2}; + uint8_t Byte = static_cast(std::strtol(Tran.c_str(), nullptr, 16)); + Res.push_back(Byte); + } + return Res; +} + +void WasiCryptoTest::writeDummyMemoryContent() { + std::fill_n(MemInst->getPointer(0), 64, UINT8_C(0xa5)); +} + +void WasiCryptoTest::writeString(std::string_view String, uint32_t Ptr) { + std::copy(String.begin(), String.end(), MemInst->getPointer(Ptr)); +} + +void WasiCryptoTest::writeSpan(Span Content, uint32_t Ptr) { + std::copy(Content.begin(), Content.end(), + MemInst->getPointer(Ptr)); +} + +void WasiCryptoTest::writeOptKey(std::optional OptKey, uint32_t Ptr) { + __wasi_opt_symmetric_key_t Key; + if (OptKey) { + Key.tag = __WASI_OPT_SYMMETRIC_KEY_U_SOME; + Key.u = {*OptKey}; + } else { + Key.tag = __WASI_OPT_SYMMETRIC_KEY_U_NONE; + } + auto *BeginPlace = MemInst->getPointer<__wasi_opt_symmetric_key_t *>(Ptr); + *BeginPlace = Key; +} + +void WasiCryptoTest::writeOptOptions(std::optional<__wasi_options_t> OptOptions, + uint32_t Ptr) { + __wasi_opt_options_t Options; + if (OptOptions) { + Options.tag = __WASI_OPT_OPTIONS_U_SOME; + Options.u = {*OptOptions}; + } else { + Options.tag = __WASI_OPT_OPTIONS_U_NONE; + } + auto *BeginPlace = MemInst->getPointer<__wasi_opt_options_t *>(Ptr); + *BeginPlace = Options; +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::arrayOutputLen(__wasi_array_output_t ArrayOutputHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoCommonMod, + "array_output_len"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{ArrayOutputHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_size_t *>(0); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::arrayOutputPull(__wasi_array_output_t ArrayOutputHandle, + Span Buf) { + writeDummyMemoryContent(); + writeSpan(Buf, 0); + uint32_t BufSize = static_cast(Buf.size()); + + auto *Func = getHostFunc(WasiCryptoCommonMod, + "array_output_pull"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + ArrayOutputHandle, 0, BufSize, BufSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(BufSize), Buf.begin()); + return *MemInst->getPointer<__wasi_size_t *>(BufSize); +} + +WasiCryptoExpect<__wasi_options_t> +WasiCryptoTest::optionsOpen(__wasi_algorithm_type_e_t AlgorithmType) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoCommonMod, "options_open"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + static_cast(AlgorithmType), 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_options_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::optionsClose(__wasi_options_t OptionsHandle) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoCommonMod, "options_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{OptionsHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect +WasiCryptoTest::optionsSet(__wasi_options_t OptionsHandle, + std::string_view Name, Span Value) { + writeDummyMemoryContent(); + writeString(Name, 0); + uint32_t NameSize = static_cast(Name.size()); + writeSpan(Value, NameSize); + uint32_t ValueSize = static_cast(Value.size()); + + auto *Func = + getHostFunc(WasiCryptoCommonMod, "options_set"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + OptionsHandle, 0, NameSize, NameSize, ValueSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect +WasiCryptoTest::optionsSetU64(__wasi_options_t OptionsHandle, + std::string_view Name, uint64_t Value) { + writeDummyMemoryContent(); + + writeString(Name, 0); + uint32_t NameSize = static_cast(Name.size()); + + auto *Func = getHostFunc(WasiCryptoCommonMod, + "options_set_u64"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + OptionsHandle, 0, NameSize, Value}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_secrets_manager_t> WasiCryptoTest::secretsManagerOpen( + std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeOptOptions(OptOptionsHandle, 0); + + auto *Func = getHostFunc(WasiCryptoCommonMod, + "secrets_manager_open"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{0, 8}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_secrets_manager_t *>(8); +} + +WasiCryptoExpect WasiCryptoTest::secretsManagerClose( + __wasi_secrets_manager_t SecretsManagerHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoCommonMod, "secrets_manager_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{SecretsManagerHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect WasiCryptoTest::secretsManagerInvalidate( + __wasi_secrets_manager_t SecretsManagerHandle, Span KeyId, + __wasi_version_t Version) { + writeDummyMemoryContent(); + writeSpan(KeyId, 0); + uint32_t KeyIdSize = static_cast(KeyId.size()); + + auto *Func = getHostFunc( + WasiCryptoCommonMod, "secrets_manager_invalidate"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + SecretsManagerHandle, 0, KeyIdSize, Version}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_symmetric_key_t> WasiCryptoTest::symmetricKeyGenerate( + std::string_view Alg, std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeOptOptions(OptOptionsHandle, AlgSize); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_key_generate"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + 0, AlgSize, AlgSize, AlgSize + 8}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_symmetric_key_t *>(AlgSize + 8); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +WasiCryptoTest::symmetricKeyImport(std::string_view Alg, + Span Raw) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeSpan(Raw, AlgSize); + uint32_t RawSize = static_cast(Raw.size()); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_key_import"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + 0, AlgSize, AlgSize, RawSize, AlgSize + RawSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_symmetric_key_t *>(AlgSize + RawSize); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::symmetricKeyExport(__wasi_symmetric_key_t KeyHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_key_export"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{KeyHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_array_output_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::symmetricKeyClose(__wasi_symmetric_key_t KeyHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_key_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{KeyHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +WasiCryptoTest::symmetricKeyGenerateManaged( + __wasi_secrets_manager_t SecretsManagerHandle, std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeOptOptions(OptOptionsHandle, AlgSize); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_key_generate_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(CallFrame, + std::initializer_list{ + SecretsManagerHandle, 0, AlgSize, AlgSize, AlgSize + 8}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_keypair_t *>(AlgSize + 8); +} + +WasiCryptoExpect WasiCryptoTest::symmetricKeyStoreManaged( + __wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t KeyHandle, Span KeyId) { + writeDummyMemoryContent(); + writeSpan(KeyId, 0); + uint32_t KpIdSize = static_cast(KeyId.size()); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_key_store_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + SecretsManagerHandle, KeyHandle, 0, KpIdSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(KpIdSize), KeyId.begin()); + + return {}; +} + +WasiCryptoExpect<__wasi_version_t> WasiCryptoTest::symmetricKeyReplaceManaged( + __wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t OldKeyHandle, __wasi_symmetric_key_t NewKeyHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_key_replace_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(CallFrame, + std::initializer_list{ + SecretsManagerHandle, OldKeyHandle, NewKeyHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_version_t *>(0); +} + +WasiCryptoExpect> +WasiCryptoTest::symmetricKeyId(__wasi_symmetric_key_t KeyHandle, + Span KeyId) { + writeDummyMemoryContent(); + writeSpan(KeyId, 0); + uint32_t KeyIdSize = static_cast(KeyId.size()); + + auto *Func = + getHostFunc(WasiCryptoSymmMod, "symmetric_key_id"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + KeyHandle, 0, KeyIdSize, KeyIdSize, KeyIdSize + 1}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(KeyIdSize), KeyId.begin()); + + return std::make_tuple( + *MemInst->getPointer(KeyIdSize), + *MemInst->getPointer<__wasi_version_t *>(KeyIdSize + 1)); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> WasiCryptoTest::symmetricKeyFromId( + __wasi_secrets_manager_t SecretsManagerHandle, Span KeyId, + __wasi_version_t KeyVersion) { + writeDummyMemoryContent(); + writeSpan(KeyId, 0); + uint32_t KeyIdSize = static_cast(KeyId.size()); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_key_from_id"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(CallFrame, + std::initializer_list{ + SecretsManagerHandle, 0, KeyIdSize, KeyVersion, KeyIdSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_symmetric_key_t *>(KeyIdSize); +} + +WasiCryptoExpect<__wasi_symmetric_state_t> WasiCryptoTest::symmetricStateOpen( + std::string_view Alg, std::optional<__wasi_symmetric_key_t> OptKeyHandle, + std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeOptKey(OptKeyHandle, AlgSize); + writeOptOptions(OptOptionsHandle, AlgSize + 8); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_state_open"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + 0, AlgSize, AlgSize, AlgSize + 8, AlgSize + 16}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_symmetric_state_t *>(AlgSize + 16); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricStateOptionsGet(__wasi_symmetric_state_t StateHandle, + std::string_view Name, + Span Value) { + writeDummyMemoryContent(); + writeString(Name, 0); + uint32_t NameSize = static_cast(Name.size()); + writeSpan(Value, NameSize); + uint32_t ValueSize = static_cast(Value.size()); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_state_options_get"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{ + StateHandle, 0, NameSize, NameSize, ValueSize, NameSize + ValueSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(NameSize), + MemInst->getPointer(NameSize + ValueSize), + Value.begin()); + + return *MemInst->getPointer<__wasi_size_t *>(NameSize + ValueSize); +} + +WasiCryptoExpect WasiCryptoTest::symmetricStateOptionsGetU64( + __wasi_symmetric_state_t StateHandle, std::string_view Name) { + writeDummyMemoryContent(); + writeString(Name, 0); + uint32_t NameSize = static_cast(Name.size()); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_state_options_get_u64"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + StateHandle, 0, NameSize, NameSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer(NameSize); +} + +WasiCryptoExpect +WasiCryptoTest::symmetricStateClose(__wasi_symmetric_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_state_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{StateHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect +WasiCryptoTest::symmetricStateAbsorb(__wasi_symmetric_state_t StateHandle, + Span Data) { + writeDummyMemoryContent(); + writeSpan(Data, 0); + uint32_t DataSize = static_cast(Data.size()); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_state_absorb"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{StateHandle, 0, DataSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_symmetric_state_t> +WasiCryptoTest::symmetricStateClone(__wasi_symmetric_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_state_clone"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{StateHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_symmetric_state_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::symmetricStateSqueeze(__wasi_symmetric_state_t StateHandle, + Span Out) { + writeDummyMemoryContent(); + writeSpan(Out, 0); + uint32_t OutSize = static_cast(Out.size()); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_state_squeeze"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{StateHandle, 0, OutSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(OutSize), Out.begin()); + + return {}; +} + +WasiCryptoExpect<__wasi_symmetric_tag_t> +WasiCryptoTest::symmetricStateSqueezeTag(__wasi_symmetric_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_state_squeeze_tag"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{StateHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_symmetric_tag_t *>(0); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +WasiCryptoTest::symmetricStateSqueezeKey(__wasi_symmetric_state_t StateHandle, + std::string_view Alg) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_state_squeeze_key"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + StateHandle, 0, AlgSize, AlgSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_symmetric_key_t *>(AlgSize); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricStateMaxTagLen(__wasi_symmetric_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_state_max_tag_len"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{StateHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_size_t *>(0); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricStateEncrypt(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) { + writeDummyMemoryContent(); + writeSpan(Out, 0); + uint32_t OutSize = static_cast(Out.size()); + writeSpan(Data, OutSize); + uint32_t DataSize = static_cast(Data.size()); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_state_encrypt"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{ + StateHandle, 0, OutSize, OutSize, DataSize, OutSize + DataSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(OutSize), Out.begin()); + + return *MemInst->getPointer<__wasi_size_t *>(OutSize + DataSize); +} + +WasiCryptoExpect<__wasi_symmetric_tag_t> +WasiCryptoTest::symmetricStateEncryptDetached( + __wasi_symmetric_state_t StateHandle, Span Out, + Span Data) { + writeDummyMemoryContent(); + writeSpan(Out, 0); + uint32_t OutSize = static_cast(Out.size()); + writeSpan(Data, OutSize); + uint32_t DataSize = static_cast(Data.size()); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_state_encrypt_detached"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{ + StateHandle, 0, OutSize, OutSize, DataSize, OutSize + DataSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(OutSize), Out.begin()); + + return *MemInst->getPointer<__wasi_symmetric_tag_t *>(OutSize + DataSize); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricStateDecrypt(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) { + writeDummyMemoryContent(); + writeSpan(Out, 0); + uint32_t OutSize = static_cast(Out.size()); + writeSpan(Data, OutSize); + uint32_t DataSize = static_cast(Data.size()); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_state_decrypt"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{ + StateHandle, 0, OutSize, OutSize, DataSize, OutSize + DataSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(OutSize), Out.begin()); + + return *MemInst->getPointer<__wasi_size_t *>(OutSize + DataSize); +} + +WasiCryptoExpect<__wasi_size_t> WasiCryptoTest::symmetricStateDecryptDetached( + __wasi_symmetric_state_t StateHandle, Span Out, + Span Data, Span RawTag) { + writeDummyMemoryContent(); + writeSpan(Out, 0); + uint32_t OutSize = static_cast(Out.size()); + writeSpan(Data, OutSize); + uint32_t DataSize = static_cast(Data.size()); + writeSpan(RawTag, OutSize + DataSize); + uint32_t RawTagSize = static_cast(RawTag.size()); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_state_decrypt_detached"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + StateHandle, 0, OutSize, OutSize, DataSize, + OutSize + DataSize, RawTagSize, + OutSize + DataSize + RawTagSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(OutSize), Out.begin()); + std::copy(MemInst->getPointer(OutSize + DataSize), + MemInst->getPointer(OutSize + DataSize + RawTagSize), + RawTag.begin()); + + return *MemInst->getPointer<__wasi_size_t *>(OutSize + DataSize + RawTagSize); +} + +WasiCryptoExpect +WasiCryptoTest::symmetricStateRatchet(__wasi_symmetric_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_state_ratchet"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{StateHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricMaxTagLen(__wasi_symmetric_tag_t TagHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoSymmMod, "symmetric_state_max_tag_len"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{TagHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_size_t *>(0); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricTagPull(__wasi_symmetric_tag_t TagHandle, + Span Buf) { + writeDummyMemoryContent(); + writeSpan(Buf, 0); + uint32_t BufSize = static_cast(Buf.size()); + + auto *Func = + getHostFunc(WasiCryptoSymmMod, "symmetric_tag_pull"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + TagHandle, 0, BufSize, BufSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(BufSize), Buf.begin()); + + return *MemInst->getPointer<__wasi_size_t *>(BufSize); +} + +WasiCryptoExpect +WasiCryptoTest::symmetricTagVerify(__wasi_symmetric_tag_t TagHandle, + Span RawTag) { + writeDummyMemoryContent(); + writeSpan(RawTag, 0); + uint32_t RawTagSize = static_cast(RawTag.size()); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_tag_verify"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{TagHandle, 0, RawTagSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect +WasiCryptoTest::symmetricTagClose(__wasi_symmetric_tag_t TagHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_tag_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{TagHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_keypair_t> WasiCryptoTest::keypairGenerate( + __wasi_algorithm_type_e_t AlgType, std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeOptOptions(OptOptionsHandle, AlgSize); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_generate"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{ + static_cast(AlgType), 0, AlgSize, AlgSize, AlgSize + 8}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_keypair_t *>(AlgSize + 8); +} + +WasiCryptoExpect<__wasi_keypair_t> +WasiCryptoTest::keypairImport(__wasi_algorithm_type_e_t AlgType, + std::string_view Alg, Span Encoded, + __wasi_keypair_encoding_e_t Encoding) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeSpan(Encoded, AlgSize); + uint32_t EncodedSize = static_cast(Encoded.size()); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_import"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + static_cast(AlgType), 0, AlgSize, AlgSize, + EncodedSize, static_cast(Encoding), + AlgSize + EncodedSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_keypair_t *>(AlgSize + EncodedSize); +} + +WasiCryptoExpect<__wasi_keypair_t> WasiCryptoTest::keypairGenerateManaged( + __wasi_secrets_manager_t SecretsManagerHandle, + __wasi_algorithm_type_e_t AlgType, std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeOptOptions(OptOptionsHandle, AlgSize); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_generate_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(CallFrame, + std::initializer_list{ + SecretsManagerHandle, static_cast(AlgType), 0, + AlgSize, AlgSize, AlgSize + 8}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_keypair_t *>(AlgSize + 8); +} + +WasiCryptoExpect WasiCryptoTest::keypairStoreManaged( + __wasi_secrets_manager_t SecretsManagerHandle, __wasi_keypair_t KpHandle, + Span KpId) { + writeDummyMemoryContent(); + writeSpan(KpId, 0); + uint32_t KpIdSize = static_cast(KpId.size()); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_store_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + SecretsManagerHandle, KpHandle, 0, KpIdSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(KpIdSize), KpId.begin()); + + return {}; +} + +WasiCryptoExpect<__wasi_version_t> WasiCryptoTest::keypairReplaceManaged( + __wasi_secrets_manager_t SecretsManagerHandle, __wasi_keypair_t OldKpHandle, + __wasi_keypair_t NewKpHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_replace_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + SecretsManagerHandle, OldKpHandle, NewKpHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_version_t *>(0); +} + +WasiCryptoExpect> +WasiCryptoTest::keypairId(__wasi_keypair_t KpHandle, Span KpId) { + writeDummyMemoryContent(); + writeSpan(KpId, 0); + uint32_t KpIdSize = static_cast(KpId.size()); + + auto *Func = getHostFunc(WasiCryptoAsymCommonMod, + "keypair_id"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + KpHandle, 0, KpIdSize, KpIdSize, KpIdSize + 1}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst->getPointer(0), + MemInst->getPointer(KpIdSize), KpId.begin()); + + return std::make_tuple( + *MemInst->getPointer(KpIdSize), + *MemInst->getPointer<__wasi_version_t *>(KpIdSize + 1)); +} + +WasiCryptoExpect<__wasi_keypair_t> +WasiCryptoTest::keypairFromId(__wasi_secrets_manager_t SecretsManagerHandle, + Span KpId, + __wasi_version_t KpIdVersion) { + writeDummyMemoryContent(); + writeSpan(KpId, 0); + uint32_t KpIdSize = static_cast(KpId.size()); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_from_id"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(CallFrame, + std::initializer_list{ + SecretsManagerHandle, 0, KpIdSize, KpIdVersion, KpIdSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_keypair_t *>(KpIdSize); +} + +WasiCryptoExpect<__wasi_keypair_t> +WasiCryptoTest::keypairFromPkAndSk(__wasi_publickey_t PkHandle, + __wasi_secretkey_t SkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_from_pk_and_sk"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{PkHandle, SkHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_keypair_t *>(0); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::keypairExport(__wasi_keypair_t KpHandle, + __wasi_keypair_encoding_e_t Encoding) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_export"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + KpHandle, static_cast(Encoding), 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_array_output_t *>(0); +} + +WasiCryptoExpect<__wasi_publickey_t> +WasiCryptoTest::keypairPublickey(__wasi_keypair_t KpHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_publickey"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{KpHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_publickey_t *>(0); +} + +WasiCryptoExpect<__wasi_secretkey_t> +WasiCryptoTest::keypairSecretkey(__wasi_keypair_t KpHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_secretkey"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{KpHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_secretkey_t *>(0); +} + +WasiCryptoExpect WasiCryptoTest::keypairClose(__wasi_keypair_t KpHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{KpHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_publickey_t> WasiCryptoTest::publickeyImport( + __wasi_algorithm_type_e_t AlgType, std::string_view Alg, + Span Encoded, __wasi_publickey_encoding_e_t Encoding) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeSpan(Encoded, AlgSize); + uint32_t EncodedSize = static_cast(Encoded.size()); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "publickey_import"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + static_cast(AlgType), 0, AlgSize, AlgSize, + EncodedSize, static_cast(Encoding), + AlgSize + EncodedSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_publickey_t *>(AlgSize + EncodedSize); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::publickeyExport(__wasi_publickey_t PkHandle, + __wasi_publickey_encoding_e_t Encoding) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "publickey_export"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + PkHandle, static_cast(Encoding), 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_array_output_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::publickeyVerify(__wasi_publickey_t PkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "publickey_verify"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{PkHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_publickey_t> +WasiCryptoTest::publickeyFromSecretkey(__wasi_secretkey_t SkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "publickey_from_secretkey"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{SkHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_publickey_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::publickeyClose(__wasi_publickey_t PkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "publickey_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{PkHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_secretkey_t> WasiCryptoTest::secretkeyImport( + __wasi_algorithm_type_e_t AlgType, std::string_view Alg, + Span Encoded, __wasi_secretkey_encoding_e_t Encoding) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeSpan(Encoded, AlgSize); + uint32_t EncodedSize = static_cast(Encoded.size()); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "secretkey_import"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + static_cast(AlgType), 0, AlgSize, AlgSize, + EncodedSize, static_cast(Encoding), + AlgSize + EncodedSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_secretkey_t *>(AlgSize + EncodedSize); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::secretkeyExport(__wasi_secretkey_t SkHandle, + __wasi_secretkey_encoding_e_t Encoding) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "secretkey_export"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + SkHandle, static_cast(Encoding), 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_publickey_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::secretkeyClose(__wasi_secretkey_t SkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "secretkey_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{SkHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::kxDh(__wasi_kx_publickey_t PkHandle, + __wasi_kx_secretkey_t SkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoKxMod, "kx_dh"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{PkHandle, SkHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_array_output_t *>(0); +} + +WasiCryptoExpect> +WasiCryptoTest::kxEncapsulate(__wasi_kx_publickey_t PkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoKxMod, "kx_encapsulate"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{PkHandle, 0, 1}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return std::make_tuple(*MemInst->getPointer<__wasi_array_output_t *>(0), + *MemInst->getPointer<__wasi_array_output_t *>(1)); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::kxDecapsulate(__wasi_kx_secretkey_t SkHandle, + Span EncapsulatedSecret) { + writeDummyMemoryContent(); + writeSpan(EncapsulatedSecret, 0); + uint32_t EncapsulatedSecretSize = + static_cast(EncapsulatedSecret.size()); + + auto *Func = getHostFunc(WasiCryptoKxMod, "kx_decapsulate"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{ + SkHandle, 0, EncapsulatedSecretSize, EncapsulatedSecretSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_array_output_t *>(EncapsulatedSecretSize); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::signatureExport(__wasi_signature_t SigHandle, + __wasi_signature_encoding_e_t Encoding) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoSignMod, "signature_export"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{ + SigHandle, static_cast(Encoding), 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_array_output_t *>(0); +} + +WasiCryptoExpect<__wasi_signature_t> +WasiCryptoTest::signatureImport(std::string_view Alg, + Span Encoded, + __wasi_signature_encoding_e_t Encoding) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeSpan(Encoded, AlgSize); + uint32_t EncodedSize = static_cast(Encoded.size()); + + auto *Func = + getHostFunc(WasiCryptoSignMod, "signature_import"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(CallFrame, + std::initializer_list{ + 0, AlgSize, AlgSize, EncodedSize, + static_cast(Encoding), AlgSize + EncodedSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_signature_t *>(AlgSize + EncodedSize); +} + +WasiCryptoExpect +WasiCryptoTest::signatureClose(__wasi_signature_t SigHandle) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoSignMod, "signature_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{SigHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_signature_state_t> +WasiCryptoTest::signatureStateOpen(__wasi_signature_keypair_t KpHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoSignMod, + "signature_state_open"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{KpHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_signature_state_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::signatureStateUpdate(__wasi_signature_state_t StateHandle, + Span Input) { + writeDummyMemoryContent(); + writeSpan(Input, 0); + uint32_t InputSize = static_cast(Input.size()); + + auto *Func = getHostFunc(WasiCryptoSignMod, + "signature_state_update"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{StateHandle, 0, InputSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_signature_t> +WasiCryptoTest::signatureStateSign(__wasi_signature_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoSignMod, + "signature_state_sign"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{StateHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_signature_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::signatureStateClose(__wasi_signature_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoSignMod, + "signature_state_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{StateHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_signature_verification_state_t> +WasiCryptoTest::signatureVerificationStateOpen( + __wasi_signature_publickey_t PkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoSignMod, "signature_verification_state_open"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{PkHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst->getPointer<__wasi_signature_verification_state_t *>(0); +} + +WasiCryptoExpect WasiCryptoTest::signatureVerificationStateUpdate( + __wasi_signature_verification_state_t StateHandle, + Span Input) { + writeDummyMemoryContent(); + writeSpan(Input, 0); + uint32_t InputSize = static_cast(Input.size()); + + auto *Func = getHostFunc( + WasiCryptoSignMod, "signature_verification_state_update"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{StateHandle, 0, InputSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect WasiCryptoTest::signatureVerificationStateVerify( + __wasi_signature_verification_state_t StateHandle, + __wasi_signature_t SigHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoSignMod, "signature_verification_state_verify"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, + std::initializer_list{StateHandle, SigHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect WasiCryptoTest::signatureVerificationStateClose( + __wasi_signature_verification_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoSignMod, "signature_verification_state_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + CallFrame, std::initializer_list{StateHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +// WasiCryptoExpect<__wasi_secretkey_t> WasiCryptoTest::secretkeyImport( +// __wasi_algorithm_type_e_t AlgType, std::string_view AlgStr, +// Span Encoded, __wasi_secretkey_encoding_e_t Encoding) { +// writeString(AlgStr, 0); +// writeSpan(Encoded, AlgStr.size()); +// auto Res = +// testRun( +// {static_cast(AlgType), 0, AlgStr.size(), AlgStr.size(), +// Encoded.size(), static_cast(Encoding), +// AlgStr.size() + Encoded.size()}) +// .value(); +// if (Res != __WASI_CRYPTO_ERRNO_SUCCESS) { +// return WasiCryptoUnexpect(Res); +// } +// return *MemInst->getPointer<__wasi_signature_keypair_t *>(AlgStr.size() + +// Encoded.size()); +// } + +// WasiCryptoExpect<__wasi_array_output_t> +// WasiCryptoTest::secretkeyExport(__wasi_secretkey_t SkHandle, +// __wasi_secretkey_encoding_e_t SkEncoding) { +// auto Res = testRun( +// {SkHandle, static_cast(SkEncoding), 0}) +// .value(); +// if (Res != __WASI_CRYPTO_ERRNO_SUCCESS) { +// return WasiCryptoUnexpect(Res); +// } +// return *MemInst->getPointer<__wasi_signature_keypair_t *>(0); +// } + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/helper.h b/test/plugins/wasi_crypto/helper.h new file mode 100644 index 00000000..19b1425f --- /dev/null +++ b/test/plugins/wasi_crypto/helper.h @@ -0,0 +1,404 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "asymmetric_common/module.h" +#include "common/module.h" +#include "ctx.h" +#include "helper.h" +#include "kx/module.h" +#include "signatures/module.h" +#include "symmetric/module.h" +#include "utils/error.h" + +#include "common/span.h" +#include "common/types.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define WASI_CRYPTO_EXPECT_SUCCESS(Expr, Function) \ + auto &&Expr##__Result = Function; \ + EXPECT_TRUE(Expr##__Result) \ + << "Wasi Crypto Error code: " << Errno[0].get(); \ + auto &&Expr = Expr##__Result.value() + +#define WASI_CRYPTO_EXPECT_FAILURE(Function, ErrorCode) \ + do { \ + auto Result = Function; \ + EXPECT_FALSE(Result) << "The function result should be error but success"; \ + EXPECT_EQ(Result.error(), ErrorCode); \ + } while (0) + +#define WASI_CRYPTO_EXPECT_TRUE(Function) \ + EXPECT_TRUE(Function) << "Wasi Crypto Error code: " << Errno[0].get() + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +std::vector operator"" _u8(const char *Str, std::size_t Len); + +std::vector operator"" _u8v(const char *Str, std::size_t Len); + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +/// Designed for testing. +class WasiCryptoTest : public ::testing::Test { +public: + WasiCryptoTest() : Mod(""), CallFrame(nullptr, &Mod) { + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + MemInst = Mod.findMemoryExports("memory"); + + using namespace std::literals::string_view_literals; + Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasi_crypto/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasiCrypto" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_crypto"sv)) { + if (const auto *Module = + Plugin->findModule("wasi_crypto_asymmetric_common"sv)) { + WasiCryptoAsymCommonMod = dynamicPointerCast< + WasmEdge::Host::WasiCryptoAsymmetricCommonModule>(Module->create()); + } + if (const auto *Module = Plugin->findModule("wasi_crypto_common"sv)) { + WasiCryptoCommonMod = + dynamicPointerCast( + Module->create()); + } + if (const auto *Module = Plugin->findModule("wasi_crypto_kx"sv)) { + WasiCryptoKxMod = + dynamicPointerCast( + Module->create()); + } + if (const auto *Module = Plugin->findModule("wasi_crypto_signatures"sv)) { + WasiCryptoSignMod = + dynamicPointerCast( + Module->create()); + } + if (const auto *Module = Plugin->findModule("wasi_crypto_symmetric"sv)) { + WasiCryptoSymmMod = + dynamicPointerCast( + Module->create()); + } + } + } + +protected: + void writeDummyMemoryContent(); + + void writeString(std::string_view String, uint32_t Ptr); + + void writeSpan(Span Content, uint32_t Ptr); + + void writeOptKey(std::optional OptKey, uint32_t Ptr); + + void writeOptOptions(std::optional<__wasi_options_t> OptOptions, + uint32_t Ptr); + + // Common + + WasiCryptoExpect<__wasi_size_t> + arrayOutputLen(__wasi_array_output_t ArrayOutputHandle); + + WasiCryptoExpect<__wasi_size_t> + arrayOutputPull(__wasi_array_output_t ArrayOutputHandle, Span Buf); + + WasiCryptoExpect<__wasi_options_t> + optionsOpen(__wasi_algorithm_type_e_t AlgType); + + WasiCryptoExpect optionsClose(__wasi_options_t OptionsHandle); + + WasiCryptoExpect optionsSet(__wasi_options_t OptionsHandle, + std::string_view Name, + Span Value); + + WasiCryptoExpect optionsSetU64(__wasi_options_t OptionsHandle, + std::string_view Name, uint64_t Value); + + // Not supported, buffer placement must be on a page. + // WasiCryptoExpect + // optionsSetGuestBuffer(__wasi_options_t OptionsHandle, + // std::string_view Name, Span Buf); + + WasiCryptoExpect<__wasi_secrets_manager_t> + secretsManagerOpen(std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect + secretsManagerClose(__wasi_secrets_manager_t SecretsManagerHandle); + + WasiCryptoExpect + secretsManagerInvalidate(__wasi_secrets_manager_t SecretsManagerHandle, + Span KeyId, __wasi_version_t Version); + + // Symmetric + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyGenerate(std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyImport(std::string_view Alg, Span Raw); + + WasiCryptoExpect<__wasi_array_output_t> + symmetricKeyExport(__wasi_symmetric_key_t KeyHandle); + + WasiCryptoExpect symmetricKeyClose(__wasi_symmetric_key_t KeyHandle); + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyGenerateManaged(__wasi_secrets_manager_t SecretsManagerHandle, + std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect + symmetricKeyStoreManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t KeyHandle, + Span KeyId); + + WasiCryptoExpect<__wasi_version_t> + symmetricKeyReplaceManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t OldKeyHandle, + __wasi_symmetric_key_t NewKeyHandle); + + WasiCryptoExpect> + symmetricKeyId(__wasi_symmetric_key_t KeyHandle, Span KeyId); + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyFromId(__wasi_secrets_manager_t SecretsManagerHandle, + Span KeyId, __wasi_version_t KeyVersion); + + WasiCryptoExpect<__wasi_symmetric_state_t> + symmetricStateOpen(std::string_view Alg, + std::optional<__wasi_symmetric_key_t> OptKeyHandle, + std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect<__wasi_size_t> + symmetricStateOptionsGet(__wasi_symmetric_state_t StateHandle, + std::string_view Name, Span Value); + + WasiCryptoExpect + symmetricStateOptionsGetU64(__wasi_symmetric_state_t StateHandle, + std::string_view Name); + + WasiCryptoExpect + symmetricStateClose(__wasi_symmetric_state_t StateHandle); + + WasiCryptoExpect + symmetricStateAbsorb(__wasi_symmetric_state_t StateHandle, + Span Data); + + WasiCryptoExpect<__wasi_symmetric_state_t> + symmetricStateClone(__wasi_symmetric_state_t StateHandle); + + WasiCryptoExpect + symmetricStateSqueeze(__wasi_symmetric_state_t StateHandle, + Span Out); + + WasiCryptoExpect<__wasi_symmetric_tag_t> + symmetricStateSqueezeTag(__wasi_symmetric_state_t StateHandle); + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricStateSqueezeKey(__wasi_symmetric_state_t StateHandle, + std::string_view Alg); + + WasiCryptoExpect<__wasi_size_t> + symmetricStateMaxTagLen(__wasi_symmetric_state_t StateHandle); + + WasiCryptoExpect<__wasi_size_t> + symmetricStateEncrypt(__wasi_symmetric_state_t StateHandle, Span Out, + Span Data); + + WasiCryptoExpect<__wasi_symmetric_tag_t> + symmetricStateEncryptDetached(__wasi_symmetric_state_t StateHandle, + Span Out, Span Data); + + WasiCryptoExpect<__wasi_size_t> + symmetricStateDecrypt(__wasi_symmetric_state_t StateHandle, Span Out, + Span Data); + + WasiCryptoExpect<__wasi_size_t> + symmetricStateDecryptDetached(__wasi_symmetric_state_t StateHandle, + Span Out, Span Data, + Span RawTag); + + WasiCryptoExpect + symmetricStateRatchet(__wasi_symmetric_state_t StateHandle); + + WasiCryptoExpect<__wasi_size_t> + symmetricMaxTagLen(__wasi_symmetric_tag_t TagHandle); + + WasiCryptoExpect<__wasi_size_t> + symmetricTagPull(__wasi_symmetric_tag_t TagHandle, Span Buf); + + WasiCryptoExpect symmetricTagVerify(__wasi_symmetric_tag_t TagHandle, + Span RawTag); + + WasiCryptoExpect symmetricTagClose(__wasi_symmetric_tag_t TagHandle); + + // Asymmetric_common + + WasiCryptoExpect<__wasi_keypair_t> + keypairGenerate(__wasi_algorithm_type_e_t AlgType, std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect<__wasi_keypair_t> + keypairImport(__wasi_algorithm_type_e_t AlgType, std::string_view Alg, + Span Encoded, + __wasi_keypair_encoding_e_t Encoding); + + WasiCryptoExpect<__wasi_keypair_t> + keypairGenerateManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_algorithm_type_e_t AlgType, + std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect + keypairStoreManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_keypair_t KpHandle, Span KpId); + + WasiCryptoExpect<__wasi_version_t> + keypairReplaceManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_keypair_t OldKpHandle, + __wasi_keypair_t NewKpHandle); + + WasiCryptoExpect> + keypairId(__wasi_keypair_t KpHandle, Span KpId); + + WasiCryptoExpect<__wasi_keypair_t> + keypairFromId(__wasi_secrets_manager_t SecretsManagerHandle, + Span KpId, __wasi_version_t KpIdVersion); + + WasiCryptoExpect<__wasi_keypair_t> + keypairFromPkAndSk(__wasi_publickey_t PkHandle, __wasi_secretkey_t SkHandle); + + WasiCryptoExpect<__wasi_array_output_t> + keypairExport(__wasi_keypair_t KpHandle, + __wasi_keypair_encoding_e_t Encoding); + + WasiCryptoExpect<__wasi_publickey_t> + keypairPublickey(__wasi_keypair_t KpHandle); + + WasiCryptoExpect<__wasi_secretkey_t> + keypairSecretkey(__wasi_keypair_t KpHandle); + + WasiCryptoExpect keypairClose(__wasi_keypair_t KpHandle); + + WasiCryptoExpect<__wasi_publickey_t> + publickeyImport(__wasi_algorithm_type_e_t AlgType, std::string_view Alg, + Span Encoded, + __wasi_publickey_encoding_e_t Encoding); + + WasiCryptoExpect<__wasi_array_output_t> + publickeyExport(__wasi_publickey_t PkHandle, + __wasi_publickey_encoding_e_t Encoding); + + WasiCryptoExpect publickeyVerify(__wasi_publickey_t PkHandle); + + WasiCryptoExpect<__wasi_publickey_t> + publickeyFromSecretkey(__wasi_secretkey_t SkHandle); + + WasiCryptoExpect publickeyClose(__wasi_publickey_t PkHandle); + + WasiCryptoExpect<__wasi_secretkey_t> + secretkeyImport(__wasi_algorithm_type_e_t AlgType, std::string_view Alg, + Span Encoded, + __wasi_secretkey_encoding_e_t Encoding); + + WasiCryptoExpect<__wasi_array_output_t> + secretkeyExport(__wasi_secretkey_t SkHandle, + __wasi_secretkey_encoding_e_t Encoding); + + WasiCryptoExpect secretkeyClose(__wasi_secretkey_t SkHandle); + + // Key_exchange + + WasiCryptoExpect<__wasi_array_output_t> kxDh(__wasi_kx_publickey_t PkHandle, + __wasi_kx_secretkey_t SkHandle); + + WasiCryptoExpect> + kxEncapsulate(__wasi_kx_publickey_t PkHandle); + + WasiCryptoExpect<__wasi_array_output_t> + kxDecapsulate(__wasi_kx_secretkey_t SkHandle, + Span EncapsulatedSecret); + + // Signature + + WasiCryptoExpect<__wasi_array_output_t> + signatureExport(__wasi_signature_t SigHandle, + __wasi_signature_encoding_e_t Encoding); + + WasiCryptoExpect<__wasi_signature_t> + signatureImport(std::string_view Alg, Span Encoded, + __wasi_signature_encoding_e_t Encoding); + + WasiCryptoExpect signatureClose(__wasi_signature_t SigHandle); + + WasiCryptoExpect<__wasi_signature_state_t> + signatureStateOpen(__wasi_signature_keypair_t KpHandle); + + WasiCryptoExpect + signatureStateUpdate(__wasi_signature_state_t StateHandle, + Span Input); + + WasiCryptoExpect<__wasi_signature_t> + signatureStateSign(__wasi_signature_state_t StateHandle); + + WasiCryptoExpect + signatureStateClose(__wasi_signature_state_t StateHandle); + + WasiCryptoExpect<__wasi_signature_verification_state_t> + signatureVerificationStateOpen(__wasi_signature_publickey_t PkHandle); + + WasiCryptoExpect signatureVerificationStateUpdate( + __wasi_signature_verification_state_t StateHandle, + Span Input); + + WasiCryptoExpect signatureVerificationStateVerify( + __wasi_signature_verification_state_t StateHandle, + __wasi_signature_t SigHandle); + + WasiCryptoExpect signatureVerificationStateClose( + __wasi_signature_verification_state_t StateHandle); + + int32_t InvaildHandle = 9999; + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod; + WasmEdge::Runtime::Instance::MemoryInstance *MemInst; + WasmEdge::Runtime::CallingFrame CallFrame; + + std::array Errno; + + std::unique_ptr + WasiCryptoAsymCommonMod; + std::unique_ptr WasiCryptoCommonMod; + std::unique_ptr WasiCryptoKxMod; + std::unique_ptr WasiCryptoSignMod; + std::unique_ptr WasiCryptoSymmMod; +}; + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/kdf.cpp b/test/plugins/wasi_crypto/kdf.cpp new file mode 100644 index 00000000..8d64fa9f --- /dev/null +++ b/test/plugins/wasi_crypto/kdf.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Kdf) { + auto KdfTest = [this](std::string_view ExtractAlg, std::string_view ExpandAlg, + const std::vector &Key, + const std::vector &Salt, + const std::vector &Info, size_t KeySize) { + WASI_CRYPTO_EXPECT_SUCCESS(KeyHandle, symmetricKeyImport(ExtractAlg, Key)); + WASI_CRYPTO_EXPECT_SUCCESS( + ExtractStateHandle, + symmetricStateOpen(ExtractAlg, KeyHandle, std::nullopt)); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(ExtractStateHandle, Salt)); + WASI_CRYPTO_EXPECT_SUCCESS( + PrkHandle, symmetricStateSqueezeKey(ExtractStateHandle, ExpandAlg)); + WASI_CRYPTO_EXPECT_TRUE(symmetricKeyClose(KeyHandle)); + WASI_CRYPTO_EXPECT_SUCCESS( + ExpandStateHandle, + symmetricStateOpen(ExpandAlg, PrkHandle, std::nullopt)); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(ExpandStateHandle, Info)); + std::vector SqueezeKey(KeySize); + WASI_CRYPTO_EXPECT_TRUE( + symmetricStateSqueeze(ExpandStateHandle, SqueezeKey)); + + auto BothInvalid = [this](std::string_view Name, + __wasi_symmetric_state_t StateHandle) { + EXPECT_TRUE( + symmetricStateOpen(Name, InvaildHandle, std::nullopt).error() == + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGet(StateHandle, "foo"sv, {}), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGetU64(StateHandle, "foo"sv), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeTag(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateMaxTagLen(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateEncrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateEncryptDetached(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateDecrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateDecryptDetached(StateHandle, {}, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateRatchet(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + + // Clone checking. + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateClone(StateHandle), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + }; + BothInvalid(ExpandAlg, ExtractStateHandle); + BothInvalid(ExtractAlg, ExpandStateHandle); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueeze(ExtractStateHandle, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateSqueezeKey(ExpandStateHandle, ExpandAlg), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(ExtractStateHandle)); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(ExpandStateHandle)); + + WASI_CRYPTO_EXPECT_SUCCESS(NewKeyHandle, + symmetricKeyGenerate(ExtractAlg, std::nullopt)); + WASI_CRYPTO_EXPECT_TRUE(symmetricKeyClose(NewKeyHandle)); + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyGenerate(ExpandAlg, std::nullopt), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_FEATURE); + }; + KdfTest("HKDF-EXTRACT/SHA-256"sv, "HKDF-EXPAND/SHA-256"sv, "IKM"_u8, + "salt"_u8, "info"_u8, 32); + KdfTest("HKDF-EXTRACT/SHA-512"sv, "HKDF-EXPAND/SHA-512"sv, "IKM"_u8, + "salt"_u8, "info"_u8, 64); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/kx.cpp b/test/plugins/wasi_crypto/kx.cpp new file mode 100644 index 00000000..a47b728c --- /dev/null +++ b/test/plugins/wasi_crypto/kx.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, KxDh) { + + auto KxDhTest = [this](std::string_view Alg, const std::vector &Pk1, + const std::vector &Sk1, + const std::vector &Pk2, + const std::vector &Sk2, + const std::vector &SharedSecret) { + SCOPED_TRACE(Alg); + WASI_CRYPTO_EXPECT_SUCCESS( + Pk1Handle, publickeyImport(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, Pk1, + __WASI_PUBLICKEY_ENCODING_RAW)); + WASI_CRYPTO_EXPECT_SUCCESS( + Sk1Handle, secretkeyImport(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, Sk1, + __WASI_SECRETKEY_ENCODING_RAW)); + WASI_CRYPTO_EXPECT_SUCCESS( + Pk2Handle, publickeyImport(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, Pk2, + __WASI_PUBLICKEY_ENCODING_RAW)); + WASI_CRYPTO_EXPECT_SUCCESS( + Sk2Handle, secretkeyImport(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, Sk2, + __WASI_SECRETKEY_ENCODING_RAW)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1Handle, kxDh(Pk1Handle, Sk2Handle)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1Size, + arrayOutputLen(SharedKey1Handle)); + EXPECT_EQ(SharedKey1Size, 32); + std::vector SharedKey1(32); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1PullSize, + arrayOutputPull(SharedKey1Handle, SharedKey1)); + EXPECT_EQ(SharedKey1PullSize, 32); + EXPECT_EQ(SharedKey1, SharedSecret); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey2Handle, kxDh(Pk2Handle, Sk1Handle)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey2Size, + arrayOutputLen(SharedKey2Handle)); + EXPECT_EQ(SharedKey2Size, 32); + std::vector SharedKey2(32); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(SharedKey2Handle, SharedKey2)); + EXPECT_EQ(SharedKey2, SharedSecret); + + /// It's only supported in OpenSSL 3.0. + /// See: https://github.com/openssl/openssl/issues/7616 + WASI_CRYPTO_EXPECT_FAILURE(kxEncapsulate(Pk1Handle), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(kxDecapsulate(Sk1Handle, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(Pk1Handle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(Sk2Handle)); + + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(Pk2Handle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(Sk1Handle)); + }; + + // From: https://datatracker.ietf.org/doc/html/rfc7748#section-6.1 + KxDhTest( + "X25519"sv, + "8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a"_u8v, + "77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a"_u8v, + "de9edb7d7b7dc1b4d35b61c2ece435373f8343c85b78674dadfc7e146f882b4f"_u8v, + "5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6fd1c2f8b27ff88e0eb"_u8v, + "4a5d9d5ba4ce2de1728e3bf480350f25e07e21c947d19e3376f09b3c1e161742"_u8v); + + auto NewKxDhTest = [this](std::string_view Alg) { + SCOPED_TRACE(Alg); + + WASI_CRYPTO_EXPECT_SUCCESS( + Kp1Handle, + keypairGenerate(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS( + Kp2Handle, + keypairGenerate(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, std::nullopt)); + + WASI_CRYPTO_EXPECT_SUCCESS(Pk1Handle, keypairPublickey(Kp1Handle)); + WASI_CRYPTO_EXPECT_SUCCESS(Sk1Handle, keypairSecretkey(Kp1Handle)); + WASI_CRYPTO_EXPECT_SUCCESS(Pk2Handle, keypairPublickey(Kp2Handle)); + WASI_CRYPTO_EXPECT_SUCCESS(Sk2Handle, keypairSecretkey(Kp2Handle)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1Handle, kxDh(Pk1Handle, Sk2Handle)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1Size, + arrayOutputLen(SharedKey1Handle)); + EXPECT_EQ(SharedKey1Size, 32); + std::vector SharedKey1(32); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1PullSize, + arrayOutputPull(SharedKey1Handle, SharedKey1)); + EXPECT_EQ(SharedKey1PullSize, 32); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey2Handle, kxDh(Pk2Handle, Sk1Handle)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey2Size, + arrayOutputLen(SharedKey2Handle)); + EXPECT_EQ(SharedKey2Size, 32); + std::vector SharedKey2(32); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(SharedKey2Handle, SharedKey2)); + + EXPECT_EQ(SharedKey1, SharedKey2); + + /// It's only supported in OpenSSL 3.0. + /// See: https://github.com/openssl/openssl/issues/7616 + WASI_CRYPTO_EXPECT_FAILURE(kxEncapsulate(Pk1Handle), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(kxDecapsulate(Sk1Handle, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(Pk1Handle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(Sk2Handle)); + + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(Pk2Handle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(Sk1Handle)); + }; + NewKxDhTest("P256-SHA256"sv); + NewKxDhTest("P384-SHA384"sv); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/mac.cpp b/test/plugins/wasi_crypto/mac.cpp new file mode 100644 index 00000000..c2616e31 --- /dev/null +++ b/test/plugins/wasi_crypto/mac.cpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Mac) { + auto MacTest = [this](std::string_view Name, + const std::vector &ImportKey, + const std::vector &AbsorbData1, + const std::vector &AbsorbData2, + const std::vector &ExpectedTag1, + const std::vector &ExpectedTag2) { + SCOPED_TRACE(Name); + // Generate key hmac. + { + WASI_CRYPTO_EXPECT_SUCCESS(KeyHandle, + symmetricKeyGenerate(Name, std::nullopt)); + + // Key size checking. + WASI_CRYPTO_EXPECT_SUCCESS(KeyOutputHandle, + symmetricKeyExport(KeyHandle)); + WASI_CRYPTO_EXPECT_SUCCESS(KeySize, arrayOutputLen(KeyOutputHandle)); + EXPECT_EQ(KeySize, ImportKey.size()); + + WASI_CRYPTO_EXPECT_SUCCESS( + StateHandle, symmetricStateOpen(Name, KeyHandle, std::nullopt)); + WASI_CRYPTO_EXPECT_TRUE(symmetricKeyClose(KeyHandle)); + + // Equivalent to a single call. + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData1)); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData2)); + + WASI_CRYPTO_EXPECT_SUCCESS(TagHandle, + symmetricStateSqueezeTag(StateHandle)); + std::vector Tag(ExpectedTag1.size()); + + WASI_CRYPTO_EXPECT_SUCCESS(TagPullSize, symmetricTagPull(TagHandle, Tag)); + EXPECT_EQ(TagPullSize, ExpectedTag1.size()); + WASI_CRYPTO_EXPECT_TRUE(symmetricTagClose(TagHandle)); + + WASI_CRYPTO_EXPECT_SUCCESS(NewTagHandle, + symmetricStateSqueezeTag(StateHandle)); + + WASI_CRYPTO_EXPECT_TRUE(symmetricTagVerify(NewTagHandle, Tag)); + WASI_CRYPTO_EXPECT_TRUE(symmetricTagClose(NewTagHandle)); + + // Error case checking. + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOpen(Name, std::nullopt, std::nullopt), + __WASI_CRYPTO_ERRNO_KEY_REQUIRED); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueeze(StateHandle, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeKey(StateHandle, Name), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateMaxTagLen(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateEncrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateEncryptDetached(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateDecrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateDecryptDetached(StateHandle, {}, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateRatchet(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + + // Clone checking. + WASI_CRYPTO_EXPECT_SUCCESS(NewStateHandle, + symmetricStateClone(StateHandle)); + EXPECT_NE(StateHandle, NewStateHandle); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(NewStateHandle)); + + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(StateHandle)); + } + + // Import key hmac. + { + WASI_CRYPTO_EXPECT_SUCCESS(KeyHandle, + symmetricKeyImport(Name, ImportKey)); + WASI_CRYPTO_EXPECT_SUCCESS( + StateHandle, symmetricStateOpen(Name, KeyHandle, std::nullopt)); + WASI_CRYPTO_EXPECT_TRUE(symmetricKeyClose(KeyHandle)); + { + // Absorb "data". + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData1)); + + // SqueezeTag "data". + WASI_CRYPTO_EXPECT_SUCCESS(TagHandle, + symmetricStateSqueezeTag(StateHandle)); + std::vector Tag(ExpectedTag1.size()); + WASI_CRYPTO_EXPECT_SUCCESS(TagPullSize, + symmetricTagPull(TagHandle, Tag)); + EXPECT_EQ(TagPullSize, Tag.size()); + EXPECT_EQ(Tag, ExpectedTag1); + } + + { + // Abosorb "more_data". + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData2)); + + // SqueezeTag "datamore_data". + WASI_CRYPTO_EXPECT_SUCCESS(TagHandle, + symmetricStateSqueezeTag(StateHandle)); + std::vector Tag(ExpectedTag2.size()); + WASI_CRYPTO_EXPECT_SUCCESS(TagPullSize, + symmetricTagPull(TagHandle, Tag)); + EXPECT_EQ(TagPullSize, Tag.size()); + EXPECT_EQ(Tag, ExpectedTag2); + } + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(StateHandle)); + } + }; + MacTest( + "HMAC/SHA-256"sv, "00000000000000000000000000000000"_u8, "data"_u8, + "more_data"_u8, + "7f12a3d914ec4d1ee67dd35ff04df5a725d11a6bb78a4aafd1093f5bfbd86887"_u8v, + "77af4875ffb3932cba0c8bc5da18410c42c85eeb07072918629675e054fbc42d"_u8v); + MacTest( + "HMAC/SHA-512"sv, + "0000000000000000000000000000000000000000000000000000000000000000"_u8, + "data"_u8, "more_data"_u8, + "52fbafda16189e63730604e49c747c8281d2420e7aae34c927927e7c3cddfcea62fea554d1962a0c0d1c8177884787d8b2a88bd396d5780e3fb82b11ab33c5cc"_u8v, + "36d2dbfb50768b963fe243535bcda302750297b361b7eb079978b27177adc40338dab5c244ae90e2f11a3518ac31126a52eb5ec715c0a9476b98f73e7ff7682e"_u8v); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/notimplement.cpp b/test/plugins/wasi_crypto/notimplement.cpp new file mode 100644 index 00000000..f525152e --- /dev/null +++ b/test/plugins/wasi_crypto/notimplement.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, NotImplement) { + WASI_CRYPTO_EXPECT_FAILURE( + symmetricKeyGenerateManaged(1, "SHA-256"sv, std::nullopt), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyStoreManaged(1, 1, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyReplaceManaged(1, 1, 1), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyId(1, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyFromId(1, {}, 1), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + + WASI_CRYPTO_EXPECT_FAILURE( + keypairGenerateManaged(1, __WASI_ALGORITHM_TYPE_SIGNATURES, + "ECDSA_P256_SHA256"sv, std::nullopt), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(keypairStoreManaged(1, 1, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(keypairReplaceManaged(1, 1, 1), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(keypairId(1, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(keypairFromId(1, {}, 1), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + + WASI_CRYPTO_EXPECT_FAILURE(secretsManagerOpen(std::nullopt), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(secretsManagerClose(InvaildHandle), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(secretsManagerInvalidate(InvaildHandle, {}, 0), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/signatures.cpp b/test/plugins/wasi_crypto/signatures.cpp new file mode 100644 index 00000000..6d814992 --- /dev/null +++ b/test/plugins/wasi_crypto/signatures.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "helper.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Signatures) { + // Use the generated data to sign and verify. + auto SigTest = [this](__wasi_algorithm_type_e_t AlgType, + std::string_view Alg) { + SCOPED_TRACE(Alg); + WASI_CRYPTO_EXPECT_SUCCESS(KpHandle, + keypairGenerate(AlgType, Alg, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS(StateHandle, signatureStateOpen(KpHandle)); + WASI_CRYPTO_EXPECT_TRUE(signatureStateUpdate(StateHandle, "test"_u8)); + WASI_CRYPTO_EXPECT_TRUE(signatureStateUpdate(StateHandle, "test"_u8)); + WASI_CRYPTO_EXPECT_SUCCESS(SigHandle, signatureStateSign(StateHandle)); + WASI_CRYPTO_EXPECT_TRUE(signatureStateClose(StateHandle)); + + WASI_CRYPTO_EXPECT_SUCCESS(PkHandle, keypairPublickey(KpHandle)); + WASI_CRYPTO_EXPECT_SUCCESS(VerifictionStateHandle, + signatureVerificationStateOpen(PkHandle)); + WASI_CRYPTO_EXPECT_TRUE( + signatureVerificationStateUpdate(VerifictionStateHandle, "test"_u8)); + WASI_CRYPTO_EXPECT_TRUE( + signatureVerificationStateUpdate(VerifictionStateHandle, "test"_u8)); + WASI_CRYPTO_EXPECT_TRUE( + signatureVerificationStateVerify(VerifictionStateHandle, SigHandle)); + WASI_CRYPTO_EXPECT_TRUE( + signatureVerificationStateClose(VerifictionStateHandle)); + }; + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "ECDSA_P256_SHA256"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "ECDSA_K256_SHA256"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "ECDSA_P384_SHA384"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "Ed25519"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_2048_SHA256"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_2048_SHA384"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_2048_SHA512"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_3072_SHA384"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_3072_SHA512"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_4096_SHA512"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_2048_SHA256"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_2048_SHA384"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_2048_SHA512"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_3072_SHA384"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_3072_SHA512"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_4096_SHA512"sv); + + // Verify that a generated public key is well-formed. + auto PublicKeyVerifyTest = [this](__wasi_algorithm_type_e_t AlgType, + std::string_view Alg) { + SCOPED_TRACE(Alg); + WASI_CRYPTO_EXPECT_SUCCESS(KpHandle, + keypairGenerate(AlgType, Alg, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS(PkHandle, keypairPublickey(KpHandle)); + WASI_CRYPTO_EXPECT_TRUE(publickeyVerify(PkHandle)); + }; + PublicKeyVerifyTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "Ed25519"sv); + + // A raw public key with an invalid length is rejected on import. + auto PublicKeyImportInvalidTest = [this](__wasi_algorithm_type_e_t AlgType, + std::string_view Alg) { + SCOPED_TRACE(Alg); + WASI_CRYPTO_EXPECT_FAILURE( + publickeyImport(AlgType, Alg, "00"_u8v, __WASI_PUBLICKEY_ENCODING_RAW), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + }; + PublicKeyImportInvalidTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "Ed25519"sv); + + auto SigEncodingTest = + [this]( + std::string_view Alg, + std::map<__wasi_signature_encoding_e_t, std::vector> Data) { + SCOPED_TRACE(Alg); + for (auto &[Encoding, Sig] : Data) { + SCOPED_TRACE(Encoding); + WASI_CRYPTO_EXPECT_SUCCESS(SigHandle, + signatureImport(Alg, Sig, Encoding)); + WASI_CRYPTO_EXPECT_SUCCESS(ExportSigHandle, + signatureExport(SigHandle, Encoding)); + WASI_CRYPTO_EXPECT_SUCCESS(ExportSigSize, + arrayOutputLen(ExportSigHandle)); + std::vector ExportSig(ExportSigSize); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(ExportSigHandle, ExportSig)); + EXPECT_EQ(Sig, ExportSig); + } + }; + SigEncodingTest( + "ECDSA_K256_SHA256"sv, + {{__WASI_SIGNATURE_ENCODING_RAW, + "9D92E9FDCA3DDF2E1DDCA1E3B7A79A250B6E4AFFCABF5F9FF4D960B152AB8300E9EB978BD3DA89C42BBFE5A2C2AEB0AF1DD178FB4BCD0833B587D118F59BBB4D"_u8v}, + {__WASI_SIGNATURE_ENCODING_DER, + "30460221009d92e9fdca3ddf2e1ddca1e3b7a79a250b6e4affcabf5f9f" + "f4d960b152ab8300022100e9eb978bd3da89c42bbfe5a2c2aeb0af1dd1" + "78fb4bcd0833b587d118f59bbb4d"_u8v}}); + SigEncodingTest( + "ECDSA_P256_SHA256"sv, + {{__WASI_SIGNATURE_ENCODING_RAW, + "80D5D4769AE4F3998DD6B8B01177DE855204122A361F2189F9567C806DE2673E2FBFD3FF018338875B1D144F583EB6E8DC16CF6EEB2BB5C19A3202464ABB58BD"_u8v}, + {__WASI_SIGNATURE_ENCODING_DER, + "304502210080d5d4769ae4f3998dd6b8b01177de855204122a361f2189" + "f9567c806de2673e02202fbfd3ff018338875b1d144f583eb6e8dc16cf" + "6eeb2bb5c19a3202464abb58bd"_u8v}}); + SigEncodingTest( + "Ed25519"sv, + {{__WASI_SIGNATURE_ENCODING_RAW, + "d4fbdb52bfa726b44d1786a8c0d171c3e62ca83c9e5bbe63de0bb2483f8fd6cc1429ab72cafc41ab56af02ff8fcc43b99bfe4c7ae940f60f38ebaa9d311c4007"_u8v}}); + SigEncodingTest( + "RSA_PSS_2048_SHA256"sv, + {{__WASI_SIGNATURE_ENCODING_RAW, + "4f01e0c12b08625ecac89a69231906edf826380f37c959a96690d046316d68ff" + "ce9d5c471694fcebfc6b45534864689256e4fc81c78e583f675d0c94b4496474" + "51e81beff01a11a516d5e5ce3f1a910437cb8a3a5096b19fb15f4524a35b23d8" + "9cdba12cf5b71aac1047b28c562df7c5542c34ce23a182cf7e0e231934b17294" + "799d44877a1d68ef1b8f073619b7618e6b7c22db20030d98cf591ffc3d4da5f5" + "8613ecd5ecfc3b40a1d02f40891ca43695cd4c088b05a8054c89c595a47e2748" + "16f35384226f74459ee63e25a1bfc03c360490552ec38343f8ace502f065303b" + "00bc0ec320711b211fde92e57feb9013c3609342495ec0d7cabdec21e54acc38"_u8v}}); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_logging/CMakeLists.txt b/test/plugins/wasi_logging/CMakeLists.txt new file mode 100644 index 00000000..923363a2 --- /dev/null +++ b/test/plugins/wasi_logging/CMakeLists.txt @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasiLoggingTests + wasi_logging.cpp +) + +add_dependencies(wasiLoggingTests + wasmedgePluginWasiLogging +) + +target_include_directories(wasiLoggingTests + PUBLIC + $ + $ +) + +target_link_libraries(wasiLoggingTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) + +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasiLoggingTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasiLoggingTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasiLoggingTests wasiLoggingTests) diff --git a/test/plugins/wasi_logging/wasi_logging.cpp b/test/plugins/wasi_logging/wasi_logging.cpp new file mode 100644 index 00000000..b523a177 --- /dev/null +++ b/test/plugins/wasi_logging/wasi_logging.cpp @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "plugin/wasi_logging/func.h" +#include "plugin/wasi_logging/module.h" + +#include "common/defines.h" +#include "runtime/instance/module.h" + +#include +#include +#include + +namespace { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { + using namespace std::literals::string_view_literals; + // The built-in plugins are loaded when loading from default paths. + WasmEdge::Plugin::Plugin::loadFromDefaultPaths(); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_logging"sv)) { + if (const auto *Module = Plugin->findModule("wasi:logging/logging"sv)) { + return dynamicPointerCast( + Module->create()); + } + } + return {}; +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, uint32_t Cnt, uint8_t C = 0) noexcept { + std::fill_n(MemInst.getPointer(Offset), Cnt, C); +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, std::string_view Str) noexcept { + char *Buf = MemInst.getPointer(Offset); + std::copy_n(Str.data(), Str.length(), Buf); +} + +} // namespace + +TEST(WasiLoggingTests, func_log) { + using namespace std::literals::string_view_literals; + // Create the wasi-logging module instance. + // Create two wasi-logging modules for testing multiple modules. + auto WasiLoggingMod1 = createModule(); + ASSERT_TRUE(WasiLoggingMod1); + auto WasiLoggingMod2 = createModule(); + ASSERT_TRUE(WasiLoggingMod2); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + EXPECT_NE(MemInstPtr, nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Clear the memory[0, 256]. + fillMemContent(MemInst, 0, 256); + // Set strings in memory. + fillMemContent(MemInst, 0, "stdout"sv); + fillMemContent(MemInst, 8, "stderr"sv); + fillMemContent(MemInst, 16, "out.log"sv); + fillMemContent(MemInst, 24, "out2.log"sv); + fillMemContent(MemInst, 128, "This is log message"sv); + fillMemContent(MemInst, 160, "Message 1 to file"sv); + fillMemContent(MemInst, 192, "Message 2 to file"sv); + fillMemContent(MemInst, 224, "Message 3 to file"sv); + + // Get the function "log". + auto *FuncInst1 = WasiLoggingMod1->findFuncExports("log"); + auto *FuncInst2 = WasiLoggingMod2->findFuncExports("log"); + EXPECT_NE(FuncInst1, nullptr); + EXPECT_NE(FuncInst2, nullptr); + EXPECT_TRUE(FuncInst1->isHostFunction()); + EXPECT_TRUE(FuncInst2->isHostFunction()); + auto &HostFuncInst1 = dynamic_cast( + FuncInst1->getHostFunc()); + auto &HostFuncInst2 = dynamic_cast( + FuncInst2->getHostFunc()); + + // Show All Level + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(2), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(4), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(5), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + + // Stderr Context + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(8), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + EXPECT_TRUE(HostFuncInst2.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(8), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + + // Log to out.txt: message 1 + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(16), UINT32_C(7), UINT32_C(160), UINT32_C(17)}, + {})); + EXPECT_TRUE(HostFuncInst2.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(16), UINT32_C(7), UINT32_C(160), UINT32_C(17)}, + {})); + // Log to out2.txt: message 2 + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(24), UINT32_C(8), UINT32_C(192), UINT32_C(17)}, + {})); + EXPECT_TRUE(HostFuncInst2.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(24), UINT32_C(8), UINT32_C(192), UINT32_C(17)}, + {})); + // Log to out.txt: message 3 + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(16), UINT32_C(7), UINT32_C(224), UINT32_C(17)}, + {})); + EXPECT_TRUE(HostFuncInst2.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(16), UINT32_C(7), UINT32_C(224), UINT32_C(17)}, + {})); + + // UnKnown Level + EXPECT_FALSE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(6), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + EXPECT_FALSE(HostFuncInst2.run( + CallFrame, + std::initializer_list{ + UINT32_C(6), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt new file mode 100644 index 00000000..23cb850d --- /dev/null +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -0,0 +1,235 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasiNNTests + wasi_nn.cpp +) + +function(download URL OUTPUT HASH) + file(DOWNLOAD + ${URL} + ${OUTPUT} + SHOW_PROGRESS + EXPECTED_HASH ${HASH} + ) +endfunction() + +# Prepare the testing data for each backends. +foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) + string(TOLOWER ${BACKEND} BACKEND) + if(BACKEND MATCHES "openvino") + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures") + download( + https://github.com/intel/openvino-rs/raw/v0.3.3/crates/openvino/tests/fixtures/mobilenet/mobilenet.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/mobilenet.bin + MD5=ae096b1f735f1e8e54bac8b2a42303bd + ) + download( + https://github.com/intel/openvino-rs/raw/v0.3.3/crates/openvino/tests/fixtures/mobilenet/mobilenet.xml + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/mobilenet.xml + MD5=4ea3a14273587ce5c1662018878f9f90 + ) + download( + https://github.com/intel/openvino-rs/raw/v0.3.3/crates/openvino/tests/fixtures/mobilenet/tensor-1x224x224x3-f32.bgr + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/tensor-1x224x224x3-f32.bgr + MD5=bfca546f4a3b5e6da49b7bd728e2799a + ) + elseif(BACKEND MATCHES "pytorch") + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures") + download( + https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/pytorch-mobilenet-image/mobilenet.pt + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/mobilenet.pt + MD5=234f446d2446e0f6fd8ed700c0b4b63b + ) + download( + https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/pytorch-mobilenet-image/image-1x3x224x224.rgb + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/image-1x3x224x224.rgb + MD5=551caa6f3b66c1d953655228462570a1 + ) + elseif(BACKEND STREQUAL "tensorflowlite") + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures") + download( + https://raw.githubusercontent.com/gusye1234/WasmEdge-WASINN-examples/demo-tflite-image/tflite-birds_v1-image/lite-model_aiy_vision_classifier_birds_V1_3.tflite + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures/lite-model_aiy_vision_classifier_birds_V1_3.tflite + MD5=3e59cc3a99afeeb819c2c38b319a7938 + ) + download( + https://raw.githubusercontent.com/gusye1234/WasmEdge-WASINN-examples/demo-tflite-image/tflite-birds_v1-image/birdx224x224x3.rgb + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures/birdx224x224x3.rgb + MD5=ad51c39cfe35d2ef35c4052b78cb3c55 + ) + elseif(BACKEND STREQUAL "ggml") + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures") + if (CMAKE_SYSTEM_PROCESSOR MATCHES "s390x") + # Use a big endian model for s390x + download( + https://huggingface.co/taronaeo/Granite-3.0-1B-A400M-Instruct-BE-GGUF/resolve/main/granite-3.0-1b-a400m-instruct-be.Q2_K.gguf + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/granite-3.gguf + MD5=2520fd8a702468942fdd1595b9dca9a2 + ) + else() + download( + https://huggingface.co/TheBloke/orca_mini_v3_7B-GGUF/resolve/main/orca_mini_v3_7b.Q2_K.gguf + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca_mini.gguf + MD5=f895f00678bfbf89f70d6d25f20a7b5f + ) + endif() + if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + target_compile_options(wasiNNTests PUBLIC + /wd4067 # unexpected tokens following preprocessor directive - expected a newline + /wd4505 # unreferenced local function has been removed + ) + else() + # string_split in common.h unused + target_compile_options(wasiNNTests PUBLIC + -Wno-unused-function + ) + endif() + elseif(BACKEND STREQUAL "piper") + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures") + download( + https://github.com/OHF-Voice/piper1-gpl/raw/v1.3.0/tests/test_voice.onnx + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures/test_voice.onnx + SHA256=1c8bbb420741358f0a356bb83eaae1b4161fbb5974f6941e10eb5a1725d78994 + ) + download( + https://github.com/OHF-Voice/piper1-gpl/raw/v1.3.0/tests/test_voice.onnx.json + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures/test_voice.onnx.json + SHA256=ccd28e02c334fbcfc94a86c8f86f1d7dbb5bffc844af9f22243a1f9f7840db1b + ) + set(ESPEAK_SOURCE_DIR "") + + if (DEFINED PIPER_ROOT) + set(ESPEAK_SOURCE_DIR "${PIPER_ROOT}/espeak-ng-data") + elseif(EXISTS "${CMAKE_BINARY_DIR}/espeak_ng-install/share/espeak-ng-data") + set(ESPEAK_SOURCE_DIR "${CMAKE_BINARY_DIR}/espeak_ng-install/share/espeak-ng-data") + endif() + + if (EXISTS "${ESPEAK_SOURCE_DIR}") + file( + COPY ${ESPEAK_SOURCE_DIR} + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures + ) + else() + message(WARNING "Could not find espeak-ng-data at ${ESPEAK_SOURCE_DIR}") + endif() + if(DEFINED PIPER_ROOT) + find_library(ESPEAK_NG_LIB + NAMES espeak-ng libespeak-ng + PATHS /usr/local/lib /usr/local/lib64 + NO_DEFAULT_PATH + ) + if (NOT ESPEAK_NG_LIB) + find_library(ESPEAK_NG_LIB NAMES espeak-ng libespeak-ng) + endif() + find_library(UCD_LIB + NAMES ucd libucd + PATHS /usr/local/lib /usr/local/lib64 + NO_DEFAULT_PATH + ) + if (NOT UCD_LIB) + find_library(UCD_LIB NAMES ucd libucd) + endif() + set(ESPEAK_TARGETS ${ESPEAK_NG_LIB} ${UCD_LIB}) + + else() + set(ESPEAK_TARGETS "") + endif() + + target_link_libraries(wasiNNTests + PRIVATE + onnxruntime + ${ESPEAK_TARGETS} + ) + elseif(BACKEND STREQUAL "whisper") + message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_whisper_fixtures") + download( + https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_whisper_fixtures/ggml-base.bin + MD5=4279db3d7b18d9f6e4d5817a16af4f09 + ) + download( + https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/wasmedge-ggml/whisper/test.wav + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_whisper_fixtures/test.wav + MD5=6cf3f7af1ebbd6b29c373e526b548dba + ) + elseif(BACKEND STREQUAL "mlx") + message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_mlx_fixtures") + download( + https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/model.safetensors + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_mlx_fixtures/model.safetensors + MD5=59e1605b3af5f1673eb8396251d6bc46 + ) + download( + https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_mlx_fixtures/tokenizer.json + MD5=c9dc953a24ad2b76b4bae4bf456f18bd + ) + target_compile_options(wasiNNTests PUBLIC + -Wno-unused-parameter + ) + + elseif(BACKEND STREQUAL "bitnet") + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_bitnet_fixtures") + download( + https://huggingface.co/microsoft/bitnet-b1.58-2B-4T-gguf/resolve/main/ggml-model-i2_s.gguf + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_bitnet_fixtures/ggml-model-i2_s.gguf + MD5=65cb04366e4d02ccd78b4b7b48c84b3b + ) + if(NOT CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + target_compile_options(wasiNNTests PUBLIC + -Wno-unused-function + ) + endif() + else() + # Add the other backend test files fetching here. + endif() +endforeach() + +add_dependencies(wasiNNTests + wasmedgePluginWasiNN +) + +include(WASINNDeps) +wasmedge_setup_wasinn_target(wasiNNTests) + +target_include_directories(wasiNNTests + PUBLIC + $ + $ +) + +target_link_libraries(wasiNNTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) + +if(TARGET build_info) + target_link_libraries(wasiNNTests + PRIVATE + build_info + ) +endif() + +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasiNNTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasiNNTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasiNNTests wasiNNTests) + +if(WASMEDGE_BUILD_WASI_NN_RPC) + add_definitions(-DWASMEDGE_BUILD_WASI_NN_RPC) + target_link_libraries(wasiNNTests + PRIVATE + wasiNNRPC + ) +endif() diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp new file mode 100644 index 00000000..514a0923 --- /dev/null +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -0,0 +1,3465 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinnfunc.h" +#include "wasinnmodule.h" + +#include "common/types.h" +#include "runtime/instance/module.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace std::literals; +using WasmEdge::Host::WASINN::Backend; +using WasmEdge::Host::WASINN::Device; +using WasmEdge::Host::WASINN::ErrNo; +using WasmEdge::Host::WASINN::TensorType; + +#if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET) +namespace { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr +createModule(std::string_view NNRPCURI = "") { + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasi_nn/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasiNN" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_nn"sv)) { + WasmEdge::PO::ArgumentParser Parser; + Plugin->registerOptions(Parser); + if (NNRPCURI != "") { + Parser.set_raw_value("nn-rpc-uri"sv, std::string(NNRPCURI)); + } + if (const auto *Module = Plugin->findModule("wasi_nn"sv)) { + return dynamicPointerCast(Module->create()); + } + } + return {}; +} + +inline std::vector readEntireFile + [[maybe_unused]] (const std::string &Path) { + std::ifstream Fin(Path, std::ios::in | std::ios::binary | std::ios::ate); + if (!Fin) { + return {}; + } + std::vector Buf(static_cast(Fin.tellg())); + Fin.seekg(0, std::ios::beg); + if (!Fin.read(reinterpret_cast(Buf.data()), + static_cast(Buf.size()))) { + return {}; + } + Fin.close(); + return Buf; +} + +template +void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + WasmEdge::Span Binaries, uint32_t Ptr) noexcept { + std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); +} + +void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Value, uint32_t &Ptr) { + MemInst.storeValue(Value, Ptr); + Ptr += 4; +} + +void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t PtrVal, uint32_t PtrSize, uint32_t &Ptr) { + writeUInt32(MemInst, PtrVal, Ptr); + writeUInt32(MemInst, PtrSize, Ptr); +} + +#if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) +template +std::vector classSort(WasmEdge::Span Array) { + std::vector Indices(Array.size()); + std::iota(Indices.begin(), Indices.end(), 0); + std::sort(Indices.begin(), Indices.end(), + [&Array](size_t Left, size_t Right) -> bool { + // Sort indices according to the corresponding array elements. + return Array[Left] > Array[Right]; + }); + return Indices; +} +#endif +} // namespace +#endif + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO +TEST(WasiNNTest, OpenVINOBackend) { + // Create the wasi_nn module instance. + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(400))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::vector TensorDataLegecy = + readEntireFile("./wasinn_openvino_fixtures/tensor-1x224x224x3-f32.bgr"); + std::vector XmlRead = + readEntireFile("./wasinn_openvino_fixtures/mobilenet.xml"); + std::vector WeightRead = + readEntireFile("./wasinn_openvino_fixtures/mobilenet.bin"); + + // Convert the NHWC to NCHW format. + // For historical reasons, the OpenVINO model expects the input tensor in + // NCHW format, while the input tensor is in NHWC format. + // https://github.com/intel/openvino-rs/blob/v0.3.3/crates/openvino/tests/fixtures/mobilenet/build.sh#L39 + // https://github.com/intel/openvino-rs/blob/v0.8.0/crates/openvino/tests/classify-mobilenet.rs#L34 + ASSERT_EQ(TensorDataLegecy.size(), 3 * 224 * 224 * 4); + std::vector TensorData(TensorDataLegecy.size()); + + for (size_t C = 0; C < 3; ++C) { + for (size_t H = 0; H < 224; ++H) { + for (size_t W = 0; W < 224; ++W) { + size_t Loc = H * 224 + W; + for (size_t B = 0; B < 4; ++B) { + TensorData[(C * 224 * 224 + Loc) * 4 + B] = + TensorDataLegecy[(3 * Loc + C) * 4 + B]; + } + } + } + } + + std::vector TensorDim{1, 3, 224, 224}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(410 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Temp. values. + std::vector NNGraphTmp; + std::vector NNContextTmp; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // OpenVINO WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::RuntimeError)); + } + + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- OpenVINO model xml ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, static_cast(XmlRead.size()), + BuilderPtr); + writeFatPointer(MemInst, StorePtr + static_cast(XmlRead.size()), + static_cast(WeightRead.size()), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- OpenVINO model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(XmlRead.size()), + BuilderPtr); + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- wrong builder count. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(XmlRead.size()), + BuilderPtr); + writeFatPointer(MemInst, StorePtr + static_cast(XmlRead.size()), + static_cast(WeightRead.size()), BuilderPtr); + writeBinaries(MemInst, XmlRead, StorePtr); + writeBinaries(MemInst, WeightRead, StorePtr + XmlRead.size()); + StorePtr += (XmlRead.size() + WeightRead.size()); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(4), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- unsupported device. CPU 0, GPU 1, TPU 2 + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::AUTO), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- load successfully. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: load -- load second graph. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // OpenVINO WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + NNGraphTmp.emplace_back(Backend::OpenVINO); + NNGraphTmp.back().setReady(); + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + // Test: init_execution_context -- graph id exceeds. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::MissingMemory)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: init_execution_context -- initialize context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: init_execution_context -- initialize the second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(1), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // OpenVINO WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Swap to the tmp. env. + NNContextTmp.emplace_back(0, NNGraphTmp[0]); + NNContextTmp.back().setReady(); + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- empty context. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::MissingMemory)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: set_input -- input index exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(10), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- tensor type not FP32. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // OpenVINO WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + // Test: compute -- empty context. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::RuntimeError)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(1)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // OpenVINO WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output index exceeds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(10), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), UINT32_C(4004)); + const auto OutputClassification = + MemInst.getSpan(StorePtr, 1001).subspan(1); + std::vector SortedIndex, CorrectClasses{963, 762, 909, 926, 567}; + SortedIndex = classSort(OutputClassification); + // The probability of class i is placed at buffer[i]. + for (size_t I = 0; I < CorrectClasses.size(); I++) { + EXPECT_EQ(SortedIndex[I], CorrectClasses[I]); + } + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +TEST(WasiNNTest, PyTorchBackend) { + // Create the wasi_nn module instance. + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(400))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::vector TensorData = + readEntireFile("./wasinn_pytorch_fixtures/image-1x3x224x224.rgb"); + std::vector WeightRead = + readEntireFile("./wasinn_pytorch_fixtures/mobilenet.pt"); + + std::vector TensorDim{1, 3, 224, 224}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(410 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Temp. values. + std::vector NNGraphTmp; + std::vector NNContextTmp; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Torch WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- Torch model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- wrong builder count. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(WeightRead.size()), + BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + StorePtr += WeightRead.size(); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- unsupported device. CPU 0, GPU 1, TPU 2 + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::AUTO), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- load successfully. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: load -- load second graph. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // Torch WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + NNGraphTmp.emplace_back(Backend::PyTorch); + NNGraphTmp.back().setReady(); + // Test: init_execution_context -- graph id exceeds. + // TODO: add a non-null test for PyTorch. + // NNGraphTmp.swap(NNMod->getEnv().NNGraph); + // NNContextTmp.swap(NNMod->getEnv().NNContext); + // { + // EXPECT_TRUE(HostFuncInit.run( + // CallFrame, + // std::initializer_list{UINT32_C(0), + // BuilderPtr}, Errno)); + // EXPECT_EQ(Errno[0].get(), + // static_cast(ErrNo::MissingMemory)); + // } + // Swap back. + // NNGraphTmp.swap(NNMod->getEnv().NNGraph); + // NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: init_execution_context -- initialize context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: init_execution_context -- initialize the second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(1), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // Torch WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + NNContextTmp.emplace_back(0, NNGraphTmp[0]); + NNContextTmp.back().setReady(); + + // Test: set_input -- tensor type not FP32. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // Torch WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + // Test: compute -- empty context. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(1)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Torch WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output index exceeds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(10), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), UINT32_C(4000)); + const auto OutputClassification = + MemInst.getSpan(StorePtr, 1000); + std::vector SortedIndex, CorrectClasses{954, 940, 951, 950, 953}; + SortedIndex = classSort(OutputClassification); + // The probability of class i is placed at buffer[i]. + for (size_t I = 0; I < CorrectClasses.size(); I++) { + EXPECT_EQ(SortedIndex[I], CorrectClasses[I]); + } + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +TEST(WasiNNTest, TFLiteBackend) { + // Create the wasi_nn module instance. + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(400))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::vector TensorData = + readEntireFile("./wasinn_tflite_fixtures/birdx224x224x3.rgb"); + std::vector WeightRead = + readEntireFile("./wasinn_tflite_fixtures/" + "lite-model_aiy_vision_classifier_birds_V1_3.tflite"); + spdlog::info("Read {}"sv, TensorData.size()); + std::vector TensorDim{1, 224, 224, 3}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(410 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Temp. values. + std::vector NNGraphTmp; + std::vector NNContextTmp; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Torch WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + static_cast(Device::CPU), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- wrong builder count. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(WeightRead.size()), + BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + StorePtr += WeightRead.size(); + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), + static_cast(Backend::TensorflowLite), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- unsupported device. CPU 0, GPU 1, TPU 2 + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + static_cast(Device::AUTO), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- load successfully. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: load -- load second graph. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + // Test: init_execution_context -- graph id exceeds. + NNGraphTmp.emplace_back(Backend::TensorflowLite); + NNGraphTmp.back().setReady(); + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::MissingMemory)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: init_execution_context -- initialize context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: init_execution_context -- initialize the second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(1), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // Torch WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + NNContextTmp.emplace_back(0, NNGraphTmp[0]); + NNContextTmp.back().setReady(); + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + // Tensor type U8 + writeUInt32(MemInst, UINT32_C(3), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // Torch WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + // Test: compute -- empty context. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::MissingMemory)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(1)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + // WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output index exceeds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(10), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), UINT32_C(965)); + const auto OutputClassification = + MemInst.getSpan(StorePtr, 965); + std::vector SortedIndex, CorrectClasses{166, 158, 34, 778, 819}; + // FIXME: classSort causing segmentation fault + SortedIndex = classSort(OutputClassification); + + // The probability of class i is placed at buffer[i]. + for (size_t I = 0; I < CorrectClasses.size(); I++) { + EXPECT_EQ(OutputClassification[SortedIndex[I]], + OutputClassification[CorrectClasses[I]]); + } + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +TEST(WasiNNTest, GGMLBackend) { + // Create the wasi_nn module instance. + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::string Prompt = "Once upon a time, "; + std::vector TensorData(Prompt.begin(), Prompt.end()); + std::string Model = WasmEdge::Endian::native == WasmEdge::Endian::little + ? "./wasinn_ggml_fixtures/orca_mini.gguf" + : "./wasinn_ggml_fixtures/granite-3.gguf"; + std::vector WeightRead = readEntireFile(Model); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // GGML WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::GGML), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::GGML), + static_cast(Device::CPU), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), static_cast(Backend::GGML), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- GGML model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::GGML), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- wrong metadata encoding when builders length > 1. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(WeightRead.size()), + BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + StorePtr += static_cast(WeightRead.size()); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::GGML), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidEncoding)); + } + + // Test: load -- load successfully. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::GGML), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // GGML WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: init_execution_context -- initialize context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // GGML WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, + StorePtr + + static_cast(TensorDim.size()) * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += static_cast(TensorDim.size() * 4 + TensorData.size()); + + // GGML WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: compute -- compute until finish or context full. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_TRUE( + Errno[0].get() == static_cast(ErrNo::Success) || + Errno[0].get() == static_cast(ErrNo::ContextFull)); + } + + // GGML WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 50 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 50); + } +} +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +TEST(WasiNNTest, GGMLBackendWithRPC) { + // wasi_nn_rpcserver has to be started outside this test, + // and the URI has to be set to $WASI_NN_RPC_TEST_URI. + // nn-preload has to be specified for "default". + /* + DIR=/tmp/build + export WASI_NN_RPC_TEST_URI=unix://${DIR}/wasi_nn_rpc.sock + export WASMEDGE_PLUGIN_PATH=${DIR}/plugins/wasi_nn + ${DIR}/tools/wasmedge/wasi_nn_rpcserver \ + --nn-rpc-uri=$WASI_NN_RPC_TEST_URI \ + --nn-preload=default:GGML:AUTO:${DIR}/test/plugins/wasi_nn/wasinn_ggml_fixtures/orca_mini.gguf + */ + const auto NNRPCURI = ::getenv("WASI_NN_RPC_TEST_URI"); + if (NNRPCURI == nullptr) { + GTEST_SKIP() << "WASI_NN_RPC_TEST_URI is unset"; + } + + // Create the wasi_nn module instance. + auto NNMod = createModule(NNRPCURI); + ASSERT_TRUE(NNMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + std::string Prompt = "Once upon a time, "; + std::vector TensorData(Prompt.begin(), Prompt.end()); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load_by_name". + auto FuncInst = NNMod->findFuncExports("load_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoadByName = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "load_by_name_with_config". + FuncInst = NNMod->findFuncExports("load_by_name_with_config"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoadByNameWithConfig = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Test: load_by_name -- load successfully. + { + std::string Name = "default"; + std::vector NameVec(Name.begin(), Name.end()); + writeBinaries(MemInst, NameVec, LoadEntryPtr); + EXPECT_TRUE(HostFuncLoadByName.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, static_cast(NameVec.size()), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: load_by_name_with_config -- load successfully. + { + std::string Name = "default"; + std::string Config = "{}"; + std::vector NameVec(Name.begin(), Name.end()); + std::vector ConfigVec(Config.begin(), Config.end()); + uint32_t ConfigPtr = LoadEntryPtr + NameVec.size(); + writeBinaries(MemInst, NameVec, LoadEntryPtr); + writeBinaries(MemInst, ConfigVec, ConfigPtr); + EXPECT_TRUE(HostFuncLoadByNameWithConfig.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, static_cast(NameVec.size()), ConfigPtr, + static_cast(ConfigVec.size()), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // GGML WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: init_execution_context -- initialize context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // GGML WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // GGML WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: compute -- compute until finish or context full. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + // FIXME: ErrNo propagation is not supported yet + // EXPECT_TRUE( + // Errno[0].get() == static_cast(ErrNo::Success) + // || Errno[0].get() == + // static_cast(ErrNo::ContextFull)); + } + + // GGML WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 50 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 50); + } +} + +TEST(WasiNNTest, GGMLBackendComputeSingleWithRPC) { + // wasi_nn_rpcserver has to be started outside this test, + // and the URI has to be set to $WASI_NN_RPC_TEST_URI. + // nn-preload has to be specified for "default". + /* + DIR=/tmp/build + export WASI_NN_RPC_TEST_URI=unix://${DIR}/wasi_nn_rpc.sock + export WASMEDGE_PLUGIN_PATH=${DIR}/plugins/wasi_nn + ${DIR}/tools/wasmedge/wasi_nn_rpcserver \ + --nn-rpc-uri=$WASI_NN_RPC_TEST_URI \ + --nn-preload=default:GGML:AUTO:${DIR}/test/plugins/wasi_nn/wasinn_ggml_fixtures/orca_mini.gguf + */ + const auto NNRPCURI = ::getenv("WASI_NN_RPC_TEST_URI"); + if (NNRPCURI == nullptr) { + GTEST_SKIP() << "WASI_NN_RPC_TEST_URI is unset"; + } + + // Create the wasmedge_process module instance. + auto NNMod = createModule(NNRPCURI); + ASSERT_TRUE(NNMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + std::string Prompt = "Once upon a time, "; + std::vector TensorData(Prompt.begin(), Prompt.end()); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load_by_name". + auto FuncInst = NNMod->findFuncExports("load_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoadByName = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "load_by_name_with_config". + FuncInst = NNMod->findFuncExports("load_by_name_with_config"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoadByNameWithConfig = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output_single"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutputSingle = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute_single"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncComputeSingle = + dynamic_cast( + FuncInst->getHostFunc()); + + // Test: load_by_name -- load successfully. + { + std::string Name = "default"; + std::vector NameVec(Name.begin(), Name.end()); + writeBinaries(MemInst, NameVec, LoadEntryPtr); + EXPECT_TRUE(HostFuncLoadByName.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, static_cast(NameVec.size()), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: load_by_name_with_config -- load successfully. + { + std::string Name = "default"; + std::string Config = "{}"; + std::vector NameVec(Name.begin(), Name.end()); + std::vector ConfigVec(Config.begin(), Config.end()); + uint32_t ConfigPtr = LoadEntryPtr + NameVec.size(); + writeBinaries(MemInst, NameVec, LoadEntryPtr); + writeBinaries(MemInst, ConfigVec, ConfigPtr); + EXPECT_TRUE(HostFuncLoadByNameWithConfig.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, static_cast(NameVec.size()), ConfigPtr, + static_cast(ConfigVec.size()), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: init_execution_context -- initialize context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // GGML WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // GGML WASI-NN compute_single tests. + // Test: compute_single -- context id exceeds. + { + EXPECT_TRUE(HostFuncComputeSingle.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: compute_single -- call compute_single once follow by a + // get_output_single. + { + // compute_single + EXPECT_TRUE(HostFuncComputeSingle.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // get_output_single + EXPECT_TRUE(HostFuncGetOutputSingle.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // GGML WASI-NN get_output_single tests. + // Test: get_output_single -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutputSingle.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutputSingle.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } +} +#endif // WASMEDGE_BUILD_WASI_NN_RPC +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER +TEST(WasiNNTest, WhisperBackend) { + // Create the wasi_nn module instance. + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + std::vector TensorData = + readEntireFile("./wasinn_whisper_fixtures/test.wav"); + std::vector WeightRead = + readEntireFile("./wasinn_whisper_fixtures/ggml-base.bin"); + std::vector TensorDim{1, static_cast(TensorData.size())}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Whisper WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::Whisper), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::Whisper), + static_cast(Device::CPU), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), static_cast(Backend::Whisper), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- Whisper model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::Whisper), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- load successfully. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + StorePtr += WeightRead.size(); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::Whisper), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Whisper WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: init_execution_context -- initialize the second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Whisper WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // Whisper WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Whisper WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 50 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 50); + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER +TEST(WasiNNTest, PiperBackend) { + // Create the wasmedge_process module instance. + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(400))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::string Text = "This is a test."; + std::vector TensorData(Text.begin(), Text.end()); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(410 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Piper WASI-NN load tests. + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- Piper config ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, 1, BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- wrong config encoding. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, 0, BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidEncoding)); + } + + // Test: load -- load successfully. + std::string Config = + "{\"model\": \"./wasinn_piper_fixtures/test_voice.onnx\", " + "\"espeak_data\": \"./wasinn_piper_fixtures/piper/espeak-ng-data\"}"; + std::vector ConfigData(Config.begin(), Config.end()); + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, ConfigData.size(), BuilderPtr); + writeBinaries(MemInst, ConfigData, StorePtr); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Piper WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: init_execution_context -- initialize context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Piper WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, 3, BuilderPtr); + writeFatPointer(MemInst, + StorePtr + TensorDim.size() * + sizeof(decltype(TensorDim)::value_type), + TensorData.size(), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries( + MemInst, TensorData, + StorePtr + TensorDim.size() * sizeof(decltype(TensorDim)::value_type)); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += TensorDim.size() * sizeof(decltype(TensorDim)::value_type) + + TensorData.size(); + + // Piper WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Piper WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 10000 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 10000); + } + + // Piper json_input tests. + // Test: load -- load successfully. + Config = + "{\"model\": " + "\"./wasinn_piper_fixtures/test_voice.onnx\",\"espeak_data\": " + "\"./wasinn_piper_fixtures/piper/espeak-ng-data\",\"json_input\":true}"; + ConfigData = {Config.begin(), Config.end()}; + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, ConfigData.size(), BuilderPtr); + writeBinaries(MemInst, ConfigData, StorePtr); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // Test: init_execution_context -- initialize context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(1), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // First JSON input with parameters overridden. + Text = "{\"text\": \"This is a test.\", \"noise_scale\": 0.0, " + "\"length_scale\": 2.0, \"noise_w\": 0.0}"; + TensorData = {Text.begin(), Text.end()}; + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, 3, BuilderPtr); + writeFatPointer(MemInst, + StorePtr + TensorDim.size() * + sizeof(decltype(TensorDim)::value_type), + TensorData.size(), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries( + MemInst, TensorData, + StorePtr + TensorDim.size() * sizeof(decltype(TensorDim)::value_type)); + + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += TensorDim.size() * sizeof(decltype(TensorDim)::value_type) + + TensorData.size(); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(1)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + // Should output more than 40000 bytes. + EXPECT_GE(BytesWritten, 40000); + } + + // Second JSON input to check if one-time overriding works properly. + Text = "{\"text\": \"This is a test.\", \"output_type\": \"raw\", " + "\"noise_scale\": 0.0, \"noise_w\": 0.0}"; + TensorData = {Text.begin(), Text.end()}; + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, 3, BuilderPtr); + writeFatPointer(MemInst, + StorePtr + TensorDim.size() * + sizeof(decltype(TensorDim)::value_type), + TensorData.size(), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries( + MemInst, TensorData, + StorePtr + TensorDim.size() * sizeof(decltype(TensorDim)::value_type)); + + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += TensorDim.size() * sizeof(decltype(TensorDim)::value_type) + + TensorData.size(); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(1)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 30000); + // Should output less than 50000 bytes. + EXPECT_LT(BytesWritten, 50000); + EXPECT_EQ(BytesWritten, 44100); + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +TEST(WasiNNTest, ChatTTSBackend) { + // Create the wasmedge_process module instance. + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::string Prompt = "This is test prompt."; + std::vector TensorData(Prompt.begin(), Prompt.end()); + std::string config = + "{\"prompt\":\"[oral_2][laugh_0][break_6]\",\"spk_emb\":\"random\"," + "\"temperature\":0.5,\"top_k\":0,\"top_p\":0.9}"; + std::vector ConfigData(config.begin(), config.end()); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "unload". + FuncInst = NNMod->findFuncExports("unload"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncUnload = + dynamic_cast(FuncInst->getHostFunc()); + + // ChatTTS WASI-NN load tests. + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::ChatTTS), + UINT32_C(0), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), + static_cast(Backend::ChatTTS), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- load successfully. + BuilderPtr = LoadEntryPtr; + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::ChatTTS), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + // ChatTTS WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: init_execution_context -- initialize the second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // ChatTTS WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, static_cast(TensorType::U8), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // ChatTTS WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: setInput -- set metadata successfully. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, static_cast(TensorType::U8), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, ConfigData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, ConfigData, StorePtr + TensorDim.size() * 4); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(1), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + ConfigData.size()); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // ChatTTS WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 50 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 50); + } + + // ChatTTS WASI-NN unload tests. + // Test: unload -- unload successfully. + { + EXPECT_TRUE(HostFuncUnload.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +TEST(WasiNNTest, MLXBackend) { + // Create the wasi_nn module instance. + auto NNMod = createModule(); + ASSERT_TRUE(NNMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::string Prompt = "How are you?"; + std::string Tokenizer = "./wasinn_mlx_fixtures/tokenizer.json"; + std::vector TensorData(Prompt.begin(), Prompt.end()); + std::vector WeightRead = + readEntireFile("./wasinn_mlx_fixtures/model.safetensors"); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // MLX WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::MLX), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::MLX), + static_cast(Device::CPU), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), static_cast(Backend::MLX), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- MLX model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::MLX), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- load successfully. + std::string Config = + "{\"model_type\":\"tiny_llama_1.1B_chat_v1.0\", " + "\"tokenizer\":\"" + + Tokenizer + + "\", \"q_bits\": 4, \"group_size\": 128, \"is_quantized\": false}"; + std::vector ConfigData(Config.begin(), Config.end()); + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr + WeightRead.size(), ConfigData.size(), + BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + writeBinaries(MemInst, ConfigData, StorePtr + WeightRead.size()); + StorePtr += WeightRead.size() + ConfigData.size(); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::MLX), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // MLX WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: init_execution_context -- initialize the second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // MLX WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // MLX WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // MLX WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 50 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 50); + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET +TEST(WasiNNTest, BitNetBackend) { + // Create the wasi_nn module instance. + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); + + // Create the calling frame with a memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // --- Get host functions --- + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "unload". + FuncInst = NNMod->findFuncExports("unload"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncUnload = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute_single". + FuncInst = NNMod->findFuncExports("compute_single"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncComputeSingle = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "get_output_single". + FuncInst = NNMod->findFuncExports("get_output_single"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncGetOutputSingle = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "fini_single". + FuncInst = NNMod->findFuncExports("fini_single"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncFiniSingle = + dynamic_cast(FuncInst->getHostFunc()); + + // --- Test Data & Pointer Setup --- + const std::string ModelPath = "./wasinn_bitnet_fixtures/ggml-model-i2_s.gguf"; + const std::string ModelPreloadStr = "preload:" + ModelPath; + const std::string MetadataStr = R"({"n-predict": 128})"; + const std::string Prompt = "Once upon a time, "; + + uint32_t BuilderPtr = 0; + uint32_t LoadEntryPtr = 0; + uint32_t SetInputEntryPtr = 0; + uint32_t StorePtr = 65536; + const uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + std::array Errno = {0}; + uint32_t GraphId = 0; + uint32_t CtxId = 0; + + // BitNet WASI-NN load tests + // Test: load -- empty builder array. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, 0, static_cast(Backend::BitNet), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, 1, static_cast(Backend::BitNet), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- model bin ptr out of bounds. + { + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, 10, BuilderPtr); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, 1, static_cast(Backend::BitNet), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- invalid metadata encoding. + { + const std::string InvalidMetadataStr = R"({"n-predict": "not-a-number")"; + std::vector ModelPreloadVec(ModelPreloadStr.begin(), + ModelPreloadStr.end()); + std::vector InvalidMetadataVec(InvalidMetadataStr.begin(), + InvalidMetadataStr.end()); + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, ModelPreloadVec.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr + ModelPreloadVec.size(), + InvalidMetadataVec.size(), BuilderPtr); + writeBinaries(MemInst, ModelPreloadVec, StorePtr); + writeBinaries(MemInst, InvalidMetadataVec, + StorePtr + ModelPreloadVec.size()); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, 2, static_cast(Backend::BitNet), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidEncoding)); + } + // Test: load -- load successfully. + { + std::vector ModelPreloadVec(ModelPreloadStr.begin(), + ModelPreloadStr.end()); + std::vector MetadataVec(MetadataStr.begin(), MetadataStr.end()); + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, ModelPreloadVec.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr + ModelPreloadVec.size(), + MetadataVec.size(), BuilderPtr); + writeBinaries(MemInst, ModelPreloadVec, StorePtr); + writeBinaries(MemInst, MetadataVec, + StorePtr + ModelPreloadVec.size()); + StorePtr += ModelPreloadVec.size() + MetadataVec.size(); + ASSERT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, 2, static_cast(Backend::BitNet), + static_cast(Device::CPU), BuilderPtr}, + Errno)) + << "Load failed. Ensure model file exists at: " << ModelPath; + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + GraphId = *MemInst.getPointer(BuilderPtr); + BuilderPtr += 4; + } + + // BitNet WASI-NN init_execution_context tests + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, std::initializer_list{999, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: init_execution_context -- initialize context successfully. + { + ASSERT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{GraphId, BuilderPtr}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + CtxId = *MemInst.getPointer(BuilderPtr); + BuilderPtr += 4; + } + + // BitNet WASI-NN set_input tests + SetInputEntryPtr = BuilderPtr; + { + std::vector PromptData(Prompt.begin(), Prompt.end()); + std::vector PromptDim = { + static_cast(PromptData.size())}; + + writeFatPointer(MemInst, StorePtr, PromptDim.size(), BuilderPtr); + writeUInt32(MemInst, static_cast(TensorType::U8), BuilderPtr); + writeFatPointer(MemInst, StorePtr + PromptDim.size() * 4, PromptData.size(), + BuilderPtr); + writeBinaries(MemInst, PromptDim, StorePtr); + writeBinaries(MemInst, PromptData, + StorePtr + PromptDim.size() * 4); + } + // Test: set_input -- invalid context id. + { + EXPECT_TRUE(HostFuncSetInput.run( + CallFrame, + std::initializer_list{999, 0, SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- invalid tensor index. + { + EXPECT_TRUE(HostFuncSetInput.run( + CallFrame, + std::initializer_list{CtxId, 2, SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- set input successfully. + { + ASSERT_TRUE(HostFuncSetInput.run( + CallFrame, + std::initializer_list{CtxId, 0, SetInputEntryPtr}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // BitNet WASI-NN compute and get_output tests + // Test: compute -- invalid context ID. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{999}, Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: compute -- compute successfully. + { + ASSERT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{CtxId}, Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + // Test: get_output -- output buffer pointer out of bounds. + { + EXPECT_TRUE( + HostFuncGetOutput.run(CallFrame, + std::initializer_list{ + CtxId, 0, OutBoundPtr, 5, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- bytes written pointer out of bounds. + { + EXPECT_TRUE( + HostFuncGetOutput.run(CallFrame, + std::initializer_list{ + CtxId, 0, StorePtr, 5, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- get output successfully. + { + uint32_t BytesNeeded = 0; + ASSERT_TRUE( + HostFuncGetOutput.run(CallFrame, + std::initializer_list{ + CtxId, 0, StorePtr, 0, BuilderPtr}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + BytesNeeded = *MemInst.getPointer(BuilderPtr); + EXPECT_GT(BytesNeeded, 10); + } + + // BitNet compute_single tests + { + std::string FullStreamedOutput = ""; + const int MaxStreamTokens = 20; + + // Test: set_input -- set prompt to start a new streaming sequence. + { + ASSERT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + CtxId, 0, SetInputEntryPtr}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: compute_single and get_output_single in a loop. + for (int i = 0; i < MaxStreamTokens; ++i) { + ASSERT_TRUE(HostFuncComputeSingle.run( + CallFrame, std::initializer_list{CtxId}, + Errno)); + if (Errno[0].get() == + static_cast(ErrNo::EndOfSequence)) { + break; + } + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + + uint32_t SingleTokenBytes = 0; + ASSERT_TRUE(HostFuncGetOutputSingle.run( + CallFrame, + std::initializer_list{CtxId, 0, StorePtr, 32, + BuilderPtr}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + SingleTokenBytes = *MemInst.getPointer(BuilderPtr); + if (SingleTokenBytes > 0) { + auto TokenSpan = *MemInst.getBytes(StorePtr, SingleTokenBytes); + FullStreamedOutput += std::string( + reinterpret_cast(TokenSpan.data()), TokenSpan.size()); + } + } + EXPECT_GT(FullStreamedOutput.length(), 10); + + // Test: fini_single -- finalize the streaming session. + { + ASSERT_TRUE(HostFuncFiniSingle.run( + CallFrame, std::initializer_list{CtxId}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + } + + // BitNet WASI-NN unload tests. + // Test: unload -- invalid graph id. + { + EXPECT_TRUE(HostFuncUnload.run( + CallFrame, std::initializer_list{999}, Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: unload -- unload successfully and verify. + { + ASSERT_TRUE(HostFuncUnload.run( + CallFrame, std::initializer_list{GraphId}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{GraphId, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET diff --git a/test/plugins/wasm_bpf/CMakeLists.txt b/test/plugins/wasm_bpf/CMakeLists.txt new file mode 100644 index 00000000..625bfafe --- /dev/null +++ b/test/plugins/wasm_bpf/CMakeLists.txt @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasmBpfTests + simple_map_test.cpp + simple_ringbuf_test.cpp + wasm_bpf.cpp +) + +add_subdirectory(assets) + +add_dependencies(wasmBpfTests + wasmedgePluginWasmBpf + wasmBpfTestsAssets +) + +target_include_directories(wasmBpfTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmBpfTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} + wasmedgePluginWasmBpf +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmBpfTests + PRIVATE + wasmedgeCAPI + wasmedgeExecutor + ) +else() + target_link_libraries(wasmBpfTests + PRIVATE + wasmedge_shared + wasmedgeExecutor + ) +endif() + +add_test(wasmBpfTests wasmBpfTests) diff --git a/test/plugins/wasm_bpf/assets/.gitignore b/test/plugins/wasm_bpf/assets/.gitignore new file mode 100644 index 00000000..2247a974 --- /dev/null +++ b/test/plugins/wasm_bpf/assets/.gitignore @@ -0,0 +1,4 @@ +bootstrap.bpf.o +runqlat.bpf.o +simple_ringbuf.bpf.o +simple_map.bpf.o diff --git a/test/plugins/wasm_bpf/assets/CMakeLists.txt b/test/plugins/wasm_bpf/assets/CMakeLists.txt new file mode 100644 index 00000000..80749a7e --- /dev/null +++ b/test/plugins/wasm_bpf/assets/CMakeLists.txt @@ -0,0 +1,41 @@ +include(FetchContent) + +# Download wasm-bpf, copy & compile two bpf programs from that +FetchContent_Declare( + wasm_bpf + GIT_REPOSITORY https://github.com/eunomia-bpf/wasm-bpf + GIT_TAG b76be32d44c2ec1933ca28eab875b50e713855b8 +) + +message("Downloading wasm-bpf") +FetchContent_MakeAvailable(wasm_bpf) +message("Downloaded wasm-bpf") + +add_custom_command( + OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/bootstrap.bpf.o + COMMAND make -C ${wasm_bpf_SOURCE_DIR}/examples/bootstrap/ bootstrap.bpf.o + COMMAND cp ${wasm_bpf_SOURCE_DIR}/examples/bootstrap/bootstrap.bpf.o ${CMAKE_CURRENT_SOURCE_DIR} + WORKING_DIRECTORY ${wasm_bpf_SOURCE_DIR}/examples/bootstrap/ +) + +add_custom_command( + OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/runqlat.bpf.o + COMMAND make -C ${wasm_bpf_SOURCE_DIR}/examples/runqlat/ runqlat.bpf.o + COMMAND cp ${wasm_bpf_SOURCE_DIR}/examples/runqlat/runqlat.bpf.o ${CMAKE_CURRENT_SOURCE_DIR} + WORKING_DIRECTORY ${wasm_bpf_SOURCE_DIR}/examples/runqlat/ +) + +add_custom_command( + OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/simple_map.bpf.o ${CMAKE_CURRENT_SOURCE_DIR}/simple_ringbuf.bpf.o + COMMAND make -C ${CMAKE_CURRENT_SOURCE_DIR}/bpf-sources + COMMAND cp ${CMAKE_CURRENT_SOURCE_DIR}/bpf-sources/*.bpf.o ${CMAKE_CURRENT_SOURCE_DIR} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +add_custom_target( + wasmBpfTestsAssets + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/bootstrap.bpf.o + ${CMAKE_CURRENT_SOURCE_DIR}/runqlat.bpf.o + ${CMAKE_CURRENT_SOURCE_DIR}/simple_ringbuf.bpf.o + ${CMAKE_CURRENT_SOURCE_DIR}/simple_map.bpf.o +) diff --git a/test/plugins/wasm_bpf/assets/README.md b/test/plugins/wasm_bpf/assets/README.md new file mode 100644 index 00000000..fcee89d2 --- /dev/null +++ b/test/plugins/wasm_bpf/assets/README.md @@ -0,0 +1,13 @@ +# wasm_bpf Plugin tests + +This file contains BPF programs that will be used during testing. + +- `bootstrap` and `runqlat`: examples copied from `wasm-bpf`. See [here](https://github.com/eunomia-bpf/wasm-bpf/tree/main/examples) for build instructions. + +- `simple_ringbuf`: A simple eBPF program that writes fixed data to a ring buffer +- `simple_map`: A simple eBPF program that stores fixed data to a BPF map + +The sources of `simple_ringbuf` and `simple_map` are listed under +`bpf-sources`. Run `make` under that directory to build them. + +`libbpf` and `clang` are required to build them. diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/.gitignore b/test/plugins/wasm_bpf/assets/bpf-sources/.gitignore new file mode 100644 index 00000000..2cf24234 --- /dev/null +++ b/test/plugins/wasm_bpf/assets/bpf-sources/.gitignore @@ -0,0 +1,4 @@ +*.bpf.o +bootstrap.bpf.c +runqlat.bpf.c +bootstrap.h diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/Makefile b/test/plugins/wasm_bpf/assets/bpf-sources/Makefile new file mode 100644 index 00000000..b65b866c --- /dev/null +++ b/test/plugins/wasm_bpf/assets/bpf-sources/Makefile @@ -0,0 +1,20 @@ + +DEL = rm -rf +FILES = $(shell ls *.bpf.c | awk '{split($$0,a,".");print a[1]}') + +ARCH ?= $(shell uname -m | sed 's/x86_64/x86/' | sed 's/aarch64/arm64/' | sed 's/ppc64le/powerpc/' | sed 's/mips.*/mips/') +CLANG_BPF_SYS_INCLUDES = $(shell $(CLANG) -v -E - &1 \ + | sed -n '/<...> search starts here:/,/End of search list./{ s| \(/.*\)|-idirafter \1|p }') + +all: $(FILES) + +$(FILES) : % : %.bpf.c + $(DEL) $@.bpf.o + clang -g -O2 -target bpf -D__TARGET_ARCH_$(ARCH) -c $(filter %.c,$^) -o $@.bpf.o + llvm-strip -g $@.bpf.o + cp $@.bpf.o ../$@.bpf.o + +%.bpf.o : % + +clean: + $(DEL) *.bpf.o diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c b/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c new file mode 100644 index 00000000..0e8facb0 --- /dev/null +++ b/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#define SEC(name) __attribute__((section(name), used)) +#define __uint(name, val) int(*name)[val] +#define __type(name, val) typeof(val) *name + +#define __u64 unsigned long long +#define __u32 unsigned int + +#define u32 __u32 +#define u64 __u64 + +#define BPF_MAP_TYPE_HASH ((u32)1) +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 16); + __type(key, u32); + __type(value, u64); +} test_map SEC(".maps"); +static void *(*bpf_map_lookup_elem)(void *map, const void *key) = (void *)1; +static long (*bpf_map_update_elem)(void *map, const void *key, + const void *value, __u64 flags) = (void *)2; + +static const u32 INDICATING_KEY = 0xABCD; +static const u32 ADD_VALUE_1_KEY = 0xCDEF; +static const u32 ADD_VALUE_2_KEY = 0x1234; +static const u32 RESULT_VALUE_KEY = 0x7890; +SEC("tp_btf/sched_wakeup") +int sched_wakeup(void *ctx) { + // Use an element with key `0xABCD` to indicate that the userspace program + // already set values of the add values. + if (!bpf_map_lookup_elem(&test_map, &INDICATING_KEY)) { + return 0; + } + // Read the two add values from the map + u64 *val1, *val2; + val1 = bpf_map_lookup_elem(&test_map, &ADD_VALUE_1_KEY); + val2 = bpf_map_lookup_elem(&test_map, &ADD_VALUE_2_KEY); + if (!val1 || !val2) + return 0; + u64 result = (u64)(*val1) + (u64)(*val2); + // Store the result + bpf_map_update_elem(&test_map, &RESULT_VALUE_KEY, &result, 0); + return 0; +} + +char LICENSE[] SEC("license") = "Dual BSD/GPL"; diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c b/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c new file mode 100644 index 00000000..1e170cc8 --- /dev/null +++ b/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#define SEC(name) __attribute__((section(name), used)) +#define __uint(name, val) int(*name)[val] + +#define __u64 unsigned long long +#define __u32 unsigned int + +#define u32 __u32 + +#define BPF_MAP_TYPE_RINGBUF ((u32)27) + +char LICENSE[] SEC("license") = "Dual BSD/GPL"; +static void *(*bpf_ringbuf_reserve)(void *ringbuf, __u64 size, + __u64 flags) = (void *)131; + +static void (*bpf_ringbuf_submit)(void *data, __u64 flags) = (void *)132; + +struct { + __uint(type, BPF_MAP_TYPE_RINGBUF); + __uint(max_entries, 256 * 1024); +} rb SEC(".maps"); + +SEC("tp/sched/sched_process_exec") +int handle_exec(void *ctx) { + u32 send_data; + send_data = 0xABCD1234; + + /* reserve sample from BPF ringbuf */ + u32 *e = bpf_ringbuf_reserve(&rb, sizeof(send_data), 0); + if (!e) + return 0; + *e = send_data; + /* successfully submit it to user-space for post-processing */ + bpf_ringbuf_submit(e, 0); + return 0; +} diff --git a/test/plugins/wasm_bpf/simple_map_test.cpp b/test/plugins/wasm_bpf/simple_map_test.cpp new file mode 100644 index 00000000..f8b84e4e --- /dev/null +++ b/test/plugins/wasm_bpf/simple_map_test.cpp @@ -0,0 +1,292 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/defines.h" +#include "executor/executor.h" +#include "func-attach-bpf-program.h" +#include "func-bpf-map-fd-by-name.h" +#include "func-bpf-map-operate.h" +#include "func-close-bpf-object.h" +#include "func-load-bpf-object.h" +#include "plugin/plugin.h" +#include "runtime/instance/module.h" +#include "wasm-bpf-module.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasm_bpf/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { + if (const auto *Module = Plugin->findModule("wasm_bpf"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +std::filesystem::path getAssertsPath() { + std::filesystem::path thisFile(__FILE__); + return thisFile.parent_path() / "assets"; +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &memInst, + uint32_t offset, const std::vector &data) noexcept { + char *buf = memInst.getPointer(offset); + std::copy(data.begin(), data.end(), buf); +} +} // namespace + +static const uint32_t INDICATING_KEY = 0xABCD; +static const uint32_t ADD_VALUE_1_KEY = 0xCDEF; +static const uint32_t ADD_VALUE_2_KEY = 0x1234; +static const uint32_t RESULT_VALUE_KEY = 0x7890; + +TEST(WasmBpfTest, SimpleMapTest) { + using namespace std::string_view_literals; + // Test loading and attaching a BPF program and some map operations. + auto module = dynamic_cast(createModule()); + ASSERT_NE(module, nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); + // moduleInst.addHostFunc() + moduleInst.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *memoryInst = moduleInst.findMemoryExports("memory"); + ASSERT_NE(memoryInst, nullptr); + auto &memoryInstRef = *memoryInst; + WasmEdge::Executor::Executor executor((WasmEdge::Configure())); + WasmEdge::Runtime::CallingFrame CallFrame(&executor, &moduleInst); + + namespace fs = std::filesystem; + auto bpfObject = getAssertsPath() / "simple_map.bpf.o"; + + // Ensure the BPF object we need exists. + ASSERT_TRUE(fs::exists(bpfObject)); + + // Read the BPF object into Wasm memory. + std::ifstream bpfObjStream(bpfObject); + ASSERT_TRUE(bpfObjStream.is_open()); + ASSERT_TRUE(bpfObjStream.good()); + std::vector bpfObjectBytes( + (std::istreambuf_iterator(bpfObjStream)), + std::istreambuf_iterator()); + ASSERT_FALSE(bpfObjectBytes.empty()); + // Offset used to place data in memory. + uint32_t nextOffset = 1; + + // Put the BPF object in memory. + const uint32_t bpfObjectMemoryOffset = nextOffset; + fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); + nextOffset += static_cast(bpfObjectBytes.size()); + + // Write the strings to memory. + std::array strings = { + "test_map", // Map name + "sched_wakeup", // Program names + "" // An empty string + }; + std::array stringOffsets; + + for (size_t i = 0; i < strings.size(); i++) { + std::string currString(strings[i]); + std::vector bytes(currString.begin(), currString.end()); + // Ensure that strings are zero-terminated + bytes.push_back('\0'); + fillMemContent(memoryInstRef, nextOffset, bytes); + stringOffsets[i] = nextOffset; + nextOffset += static_cast(bytes.size()); + } + + // Get function "wasm_load_bpf_object" + auto *loadFunc = module->findFuncExports("wasm_load_bpf_object"); + ASSERT_NE(loadFunc, nullptr); + ASSERT_TRUE(loadFunc->isHostFunction()); + auto &loadFuncHost = + dynamic_cast(loadFunc->getHostFunc()); + + // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the + // result + std::array loadResult; + ASSERT_TRUE(loadFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(bpfObjectMemoryOffset), + WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, + loadResult)); + auto handle = loadResult[0].get(); + ASSERT_NE(handle, 0); + + // Get function `wasm_attach_bpf_program` + auto *attachFunc = module->findFuncExports("wasm_attach_bpf_program"); + ASSERT_NE(attachFunc, nullptr); + ASSERT_TRUE(attachFunc->isHostFunction()); + auto &attachFuncHost = dynamic_cast( + attachFunc->getHostFunc()); + + // Call "wasm_attach_bpf_program" to attach, and check the result + std::array attachResult; + ASSERT_TRUE(attachFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(stringOffsets[1]), + // There should be '\0' + WasmEdge::ValVariant(stringOffsets[2]), + }, + attachResult)); + ASSERT_GE(attachResult[0].get(), 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); + ASSERT_NE(mapFdFunc, nullptr); + ASSERT_TRUE(mapFdFunc->isHostFunction()); + auto &mapFdFuncHost = + dynamic_cast(mapFdFunc->getHostFunc()); + + // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result + std::array mapFdResult; + ASSERT_TRUE(mapFdFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), WasmEdge::ValVariant(stringOffsets[0])}, + mapFdResult)); + auto mapFd = mapFdResult[0].get(); + ASSERT_GE(mapFd, 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapOptFunc = module->findFuncExports("wasm_bpf_map_operate"); + EXPECT_NE(mapOptFunc, nullptr); + EXPECT_TRUE(mapOptFunc->isHostFunction()); + auto &mapOptFuncHost = + dynamic_cast(mapOptFunc->getHostFunc()); + + // A wrapper to call wasm_bpf_map_operate + auto callMapOperate = [&](int32_t fd, int32_t cmd, uint32_t key, + uint32_t value, uint32_t nextKey, + uint64_t flags) -> int32_t { + std::array callResult; + EXPECT_TRUE(mapOptFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(fd), WasmEdge::ValVariant(cmd), + WasmEdge::ValVariant(key), WasmEdge::ValVariant(value), + WasmEdge::ValVariant(nextKey), WasmEdge::ValVariant(flags)}, + callResult)); + return callResult[0].get(); + }; + + auto mapLookupElem = [&](int32_t fd, uint32_t key, + uint32_t valueOut) -> int32_t { + // key found -> returns 0 + // key not found -> returns -1 + return callMapOperate(fd, + 1, // BPF_MAP_LOOKUP_ELEM + key, valueOut, 0, 0); + }; + + auto mapUpdateElem = [&](int32_t fd, uint32_t key, + uint32_t value) -> int32_t { + return callMapOperate(fd, + 2, // BPF_MAP_UPDATE_ELEM + key, value, 0, 0); + }; + + // Helper functions that make reading and writing more convenient. + auto readU64 = [&](uint32_t offset) -> uint64_t { + const auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + return *ptr; + }; + auto writeU64 = [&](uint32_t offset, uint64_t val) { + auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + *ptr = val; + }; + + auto writeU32 = [&](uint32_t offset, uint32_t val) { + auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + *ptr = val; + }; + + // Generate two numbers, which will be stored in the map and summed by the + // eBPF program. + std::mt19937 randGen; + randGen.seed(std::random_device()()); + std::uniform_int_distribution intDist(0, + static_cast(1e16)); + uint64_t num1 = intDist(randGen); + uint64_t num2 = intDist(randGen); + + // Prepare Wasm memory to store numbers. + const uint32_t numOffset1 = nextOffset; + nextOffset += 8; + const uint32_t numOffset2 = nextOffset; + nextOffset += 8; + const uint32_t resultOffset = nextOffset; + nextOffset += 8; + const uint32_t indicatingKeyOffset = nextOffset; + nextOffset += 4; + const uint32_t num1KeyOffset = nextOffset; + nextOffset += 4; + const uint32_t num2KeyOffset = nextOffset; + nextOffset += 4; + const uint32_t resultKeyOffset = nextOffset; + nextOffset += 4; + + writeU32(num1KeyOffset, ADD_VALUE_1_KEY); + writeU32(num2KeyOffset, ADD_VALUE_2_KEY); + writeU32(resultKeyOffset, RESULT_VALUE_KEY); + writeU32(indicatingKeyOffset, INDICATING_KEY); + + writeU32(INDICATING_KEY, indicatingKeyOffset); + + writeU64(numOffset1, num1); + writeU64(numOffset2, num2); + + writeU64(resultOffset, 0); + + // Write the addend values into the map. + ASSERT_EQ(mapUpdateElem(mapFd, num1KeyOffset, numOffset1), 0); + ASSERT_EQ(mapUpdateElem(mapFd, num2KeyOffset, numOffset2), 0); + + // Write the indicator key. + // Arbitrary values are correct. We only care about the existence of the + // indicator key. + ASSERT_EQ(mapUpdateElem(mapFd, indicatingKeyOffset, numOffset1), 0); + + // Sleep for 1s and wait for the eBPF program to process. + std::this_thread::sleep_for(std::chrono::seconds(2)); + + // Read the result and check it + ASSERT_EQ(mapLookupElem(mapFd, resultKeyOffset, resultOffset), 0); + uint64_t addResult = readU64(resultOffset); + ASSERT_EQ(addResult, num1 + num2); + + // Get function `wasm_close_bpf_object` + auto *closeFunc = module->findFuncExports("wasm_close_bpf_object"); + ASSERT_NE(closeFunc, nullptr); + ASSERT_TRUE(closeFunc->isHostFunction()); + auto &closeFuncHost = + dynamic_cast(closeFunc->getHostFunc()); + + // Call "wasm_close_bpf_object" to attach, and check the result + std::array closeResult; + ASSERT_TRUE(closeFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + }, + closeResult)); + ASSERT_EQ(closeResult[0].get(), 0); +} diff --git a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp new file mode 100644 index 00000000..e2fbf94c --- /dev/null +++ b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/defines.h" +#include "executor/executor.h" +#include "func-attach-bpf-program.h" +#include "func-bpf-buffer-poll.h" +#include "func-bpf-map-fd-by-name.h" +#include "func-close-bpf-object.h" +#include "func-load-bpf-object.h" +#include "plugin/plugin.h" +#include "runtime/instance/module.h" +#include "wasm-bpf-module.h" + +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasm_bpf/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { + if (const auto *Module = Plugin->findModule("wasm_bpf"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +std::filesystem::path getAssertsPath() { + std::filesystem::path thisFile(__FILE__); + return thisFile.parent_path() / "assets"; +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &memInst, + uint32_t offset, const std::vector &data) noexcept { + char *buf = memInst.getPointer(offset); + std::copy(data.begin(), data.end(), buf); +} +class PollCallbackFunction + : public WasmEdge::Runtime::HostFunction { +public: + PollCallbackFunction() {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + uint32_t __attribute__((unused)) ctx, + uint32_t data, uint32_t data_sz) { + using namespace std; + using WasmEdge::unlikely; + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + if (data_sz < static_cast(sizeof(uint32_t))) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + const uint32_t *dataPtr = memory->getSpan(data, 1).data(); + if (unlikely(!dataPtr)) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + EXPECT_EQ(*dataPtr, UINT32_C(0xABCD1234)); + return 0; + } +}; + +} // namespace + +TEST(WasmBpfTest, SimpleRingbuf) { + using namespace std::string_view_literals; + // Test loading and attaching a BPF program and polling a buffer. + auto module = dynamic_cast(createModule()); + ASSERT_NE(module, nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); + // moduleInst.addHostFunc() + moduleInst.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *memoryInst = moduleInst.findMemoryExports("memory"); + ASSERT_NE(memoryInst, nullptr); + auto &memoryInstRef = *memoryInst; + WasmEdge::Executor::Executor executor((WasmEdge::Configure())); + WasmEdge::Runtime::CallingFrame CallFrame(&executor, &moduleInst); + + namespace fs = std::filesystem; + auto bpfObject = getAssertsPath() / "simple_ringbuf.bpf.o"; + + // Ensure the BPF object we need exists. + ASSERT_TRUE(fs::exists(bpfObject)); + + // Read the BPF object into Wasm memory. + std::ifstream bpfObjStream(bpfObject); + ASSERT_TRUE(bpfObjStream.is_open()); + ASSERT_TRUE(bpfObjStream.good()); + std::vector bpfObjectBytes( + (std::istreambuf_iterator(bpfObjStream)), + std::istreambuf_iterator()); + ASSERT_FALSE(bpfObjectBytes.empty()); + // Offset used to place data in memory. + uint32_t nextOffset = 1; + + // Put the BPF object in memory. + const uint32_t bpfObjectMemoryOffset = nextOffset; + fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); + nextOffset += static_cast(bpfObjectBytes.size()); + + // Write the strings to memory. + std::array strings = { + "rb", // Map name + "handle_exec", // Program names + "" // An empty string + }; + std::array stringOffsets; + + for (size_t i = 0; i < strings.size(); i++) { + std::string currString(strings[i]); + std::vector bytes(currString.begin(), currString.end()); + // Ensure that strings are zero-terminated + bytes.push_back('\0'); + fillMemContent(memoryInstRef, nextOffset, bytes); + stringOffsets[i] = nextOffset; + nextOffset += static_cast(bytes.size()); + } + + const uint32_t bufferPollMemoryOffset = nextOffset; + const uint32_t bufferPollSize = 256; + nextOffset += bufferPollSize; + + // Get function "wasm_load_bpf_object" + auto *loadFunc = module->findFuncExports("wasm_load_bpf_object"); + ASSERT_NE(loadFunc, nullptr); + ASSERT_TRUE(loadFunc->isHostFunction()); + auto &loadFuncHost = + dynamic_cast(loadFunc->getHostFunc()); + + // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the + // result + std::array loadResult; + ASSERT_TRUE(loadFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(bpfObjectMemoryOffset), + WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, + loadResult)); + auto handle = loadResult[0].get(); + ASSERT_NE(handle, 0); + + // Get function `wasm_attach_bpf_program` + auto *attachFunc = module->findFuncExports("wasm_attach_bpf_program"); + ASSERT_NE(attachFunc, nullptr); + ASSERT_TRUE(attachFunc->isHostFunction()); + auto &attachFuncHost = dynamic_cast( + attachFunc->getHostFunc()); + + // Call "wasm_attach_bpf_program" to attach, and check the result + std::array attachResult; + ASSERT_TRUE(attachFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(stringOffsets[1]), + // There should be '\0' + WasmEdge::ValVariant(stringOffsets[2]), + }, + attachResult)); + ASSERT_GE(attachResult[0].get(), 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); + ASSERT_NE(mapFdFunc, nullptr); + ASSERT_TRUE(mapFdFunc->isHostFunction()); + auto &mapFdFuncHost = + dynamic_cast(mapFdFunc->getHostFunc()); + + // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result + std::array mapFdResult; + ASSERT_TRUE(mapFdFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), WasmEdge::ValVariant(stringOffsets[0])}, + mapFdResult)); + auto mapFd = mapFdResult[0].get(); + ASSERT_GE(mapFd, 0); + + // In the following steps we prepare for polling. + // Create an instance of the polling callback function. + moduleInst.addHostFunc("__polling_callback_hostfunc"sv, + std::make_unique()); + auto *callbackFuncInst = + moduleInst.findFuncExports("__polling_callback_hostfunc"); + // Create a function table and fill it with the callback function. + auto funcTableInst = + std::make_unique( + WasmEdge::AST::TableType(WasmEdge::TypeCode::FuncRef, 1)); + ASSERT_TRUE(funcTableInst->setRefs( + std::initializer_list{callbackFuncInst}, 0, 0, + 1)); + // Add the table to the main module + moduleInst.addHostTable("__indirect_function_table"sv, + std::move(funcTableInst)); + + // Get the "wasm_bpf_buffer_poll" function + auto *bufferPollFunc = module->findFuncExports("wasm_bpf_buffer_poll"); + ASSERT_NE(bufferPollFunc, nullptr); + ASSERT_TRUE(bufferPollFunc->isHostFunction()); + auto &bufferPollFuncHost = dynamic_cast( + bufferPollFunc->getHostFunc()); + + // Call the polling function + std::array pollResult; + for (size_t i = 1; i <= 50; i++) { + using namespace std; + ASSERT_TRUE(bufferPollFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), // object handle + WasmEdge::ValVariant(mapFd), // map fd + UINT32_C(0), // callback function index + UINT32_C(0), // Custom context pointer + WasmEdge::ValVariant(bufferPollMemoryOffset), // buffer offset + WasmEdge::ValVariant(bufferPollSize), // buffer size + UINT32_C(100) // timeout (ms) + }, + pollResult)); + ASSERT_GE(pollResult[0].get(), 0); + } + + // Get function `wasm_close_bpf_object` + auto *closeFunc = module->findFuncExports("wasm_close_bpf_object"); + ASSERT_NE(closeFunc, nullptr); + ASSERT_TRUE(closeFunc->isHostFunction()); + auto &closeFuncHost = + dynamic_cast(closeFunc->getHostFunc()); + + // Call "wasm_close_bpf_object" to attach, and check the result + std::array closeResult; + ASSERT_TRUE(closeFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + }, + closeResult)); + ASSERT_EQ(closeResult[0].get(), 0); +} diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp new file mode 100644 index 00000000..34f8b116 --- /dev/null +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -0,0 +1,581 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ast/type.h" +#include "common/defines.h" +#include "executor/executor.h" +#include "func-attach-bpf-program.h" +#include "func-bpf-buffer-poll.h" +#include "func-bpf-map-fd-by-name.h" +#include "func-bpf-map-operate.h" +#include "func-close-bpf-object.h" +#include "func-load-bpf-object.h" +#include "plugin/plugin.h" +#include "runtime/instance/module.h" +#include "wasm-bpf-module.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasm_bpf/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { + if (const auto *Module = Plugin->findModule("wasm_bpf"sv)) { + return dynamicPointerCast( + Module->create()); + } + } + return {}; +} + +std::filesystem::path getAssertsPath() { + std::filesystem::path thisFile(__FILE__); + return thisFile.parent_path() / "assets"; +} +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &memInst, + uint32_t offset, uint32_t count, char chr = 0) noexcept { + std::fill_n(memInst.getPointer(offset), count, chr); +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &memInst, + uint32_t offset, const std::vector &data) noexcept { + char *buf = memInst.getPointer(offset); + std::copy(data.begin(), data.end(), buf); +} + +} // namespace + +TEST(WasmBpfTest, Module) { + auto module = createModule(); + ASSERT_TRUE(module); + // Test whether functions are exported + EXPECT_EQ(module->getFuncExportNum(), 6U); + EXPECT_NE(module->findFuncExports("wasm_load_bpf_object"), nullptr); + EXPECT_NE(module->findFuncExports("wasm_close_bpf_object"), nullptr); + EXPECT_NE(module->findFuncExports("wasm_attach_bpf_program"), nullptr); + EXPECT_NE(module->findFuncExports("wasm_bpf_buffer_poll"), nullptr); + EXPECT_NE(module->findFuncExports("wasm_bpf_map_fd_by_name"), nullptr); + EXPECT_NE(module->findFuncExports("wasm_bpf_map_operate"), nullptr); +} + +static const size_t TASK_COMM_LEN = 16; +static const size_t MAX_FILENAME_LEN = 127; +struct event { + int pid; + int ppid; + unsigned exit_code; + unsigned long long duration_ns; + char comm[TASK_COMM_LEN]; + char filename[MAX_FILENAME_LEN]; + char exit_event; +}; + +class PollCallbackFunction + : public WasmEdge::Runtime::HostFunction { +public: + PollCallbackFunction() {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + uint32_t __attribute__((unused)) ctx, + uint32_t data, uint32_t data_sz) { + using namespace std; + using WasmEdge::unlikely; + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + if (data_sz < static_cast(sizeof(event))) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + const event *dataPtr = memory->getSpan(data, 1).data(); + if (unlikely(!dataPtr)) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + auto nowTime = chrono::system_clock::now(); + if (dataPtr->exit_event == 1) { + fmt::print("{:%H:%M:%S} EXIT {:<16} {:<7} {:<7} [{}]"sv, nowTime, + dataPtr->comm, dataPtr->pid, dataPtr->ppid, + dataPtr->exit_code); + if (dataPtr->duration_ns != 0) { + fmt::print(" ({})"sv, dataPtr->duration_ns / 1000000); + } + fmt::print("\n"sv); + } else { + fmt::print("{:%H:%M:%S} EXEC {:<16} {:<7} {:<7} {}\n"sv, nowTime, + dataPtr->comm, dataPtr->pid, dataPtr->ppid, dataPtr->filename); + } + return 0; + } +}; + +TEST(WasmBpfTest, RunBpfProgramWithPolling) { + using namespace std::literals::string_view_literals; + // Test loading and attaching a BPF program and polling a buffer. + auto module = createModule(); + ASSERT_TRUE(module); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); + // moduleInst.addHostFunc() + moduleInst.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *memoryInst = moduleInst.findMemoryExports("memory"); + EXPECT_NE(memoryInst, nullptr); + auto &memoryInstRef = *memoryInst; + WasmEdge::Executor::Executor executor((WasmEdge::Configure())); + WasmEdge::Runtime::CallingFrame CallFrame(&executor, &moduleInst); + + namespace fs = std::filesystem; + auto bpfObject = getAssertsPath() / "bootstrap.bpf.o"; + + // Ensure the BPF object we need exists. + EXPECT_TRUE(fs::exists(bpfObject)); + + // Read the BPF object into Wasm memory. + std::ifstream bpfObjStream(bpfObject); + EXPECT_TRUE(bpfObjStream.is_open()); + EXPECT_TRUE(bpfObjStream.good()); + std::vector bpfObjectBytes( + (std::istreambuf_iterator(bpfObjStream)), + std::istreambuf_iterator()); + EXPECT_FALSE(bpfObjectBytes.empty()); + + // Fill memory with the BPF object. + const uint32_t bpfObjectMemoryOffset = 1; + fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); + + // Write `handle_exec`, the BPF function name, to memory. + const uint32_t targetHandleExecNameMemoryOffset = + bpfObjectMemoryOffset + static_cast(bpfObjectBytes.size()); + const std::string targetHandleExecName("handle_exec"); + // Zero terminated.. + std::vector targetHandleExecNameBytes(targetHandleExecName.size() + 1, + 0); + std::copy(targetHandleExecName.begin(), targetHandleExecName.end(), + targetHandleExecNameBytes.begin()); + fillMemContent(memoryInstRef, targetHandleExecNameMemoryOffset, + targetHandleExecNameBytes); + + // Write `handle_exit`, the BPF function name, to memory. + const uint32_t targetHandleExitNameMemoryOffset = + targetHandleExecNameMemoryOffset + + static_cast(targetHandleExecNameBytes.size()); + const std::string targetHandleExitName("handle_exit"); + // Zero terminated.. + std::vector targetHandleExitNameBytes(targetHandleExitName.size() + 1, + 0); + std::copy(targetHandleExitName.begin(), targetHandleExitName.end(), + targetHandleExitNameBytes.begin()); + fillMemContent(memoryInstRef, targetHandleExitNameMemoryOffset, + targetHandleExitNameBytes); + + // Fill the map name `rb` + const uint32_t mapNameMemoryOffset = + targetHandleExitNameMemoryOffset + + static_cast(targetHandleExitNameBytes.size()); + const std::string mapName("rb"); + // Zero terminated.. + std::vector mapNameBytes(mapName.size() + 1, 0); + std::copy(mapName.begin(), mapName.end(), mapNameBytes.begin()); + fillMemContent(memoryInstRef, mapNameMemoryOffset, mapNameBytes); + + // Prepare a memory area for storing polled items. + const uint32_t bufferPollMemoryOffset = + mapNameMemoryOffset + static_cast(mapNameBytes.size()); + const uint32_t bufferPollSize = 1024; + fillMemContent(memoryInstRef, bufferPollMemoryOffset, bufferPollSize, 0); + + // Get function "wasm_load_bpf_object" + auto *loadFunc = module->findFuncExports("wasm_load_bpf_object"); + EXPECT_NE(loadFunc, nullptr); + EXPECT_TRUE(loadFunc->isHostFunction()); + auto &loadFuncHost = + dynamic_cast(loadFunc->getHostFunc()); + + // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the + // result + std::array loadResult; + EXPECT_TRUE(loadFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(bpfObjectMemoryOffset), + WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, + loadResult)); + auto handle = loadResult[0].get(); + EXPECT_NE(handle, 0); + + // Get function `wasm_attach_bpf_program` + auto *attachFunc = module->findFuncExports("wasm_attach_bpf_program"); + EXPECT_NE(attachFunc, nullptr); + EXPECT_TRUE(attachFunc->isHostFunction()); + auto &attachFuncHost = dynamic_cast( + attachFunc->getHostFunc()); + + // Call "wasm_attach_bpf_program" to attach, and check the result + std::array attachResult; + EXPECT_TRUE(attachFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(targetHandleExecNameMemoryOffset), + // There should be '\0' + WasmEdge::ValVariant( + targetHandleExecNameMemoryOffset + + static_cast(targetHandleExecName.size())), + }, + attachResult)); + EXPECT_GE(attachResult[0].get(), 0); + EXPECT_TRUE(attachFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(targetHandleExitNameMemoryOffset), + // There should be '\0' + WasmEdge::ValVariant( + targetHandleExitNameMemoryOffset + + static_cast(targetHandleExitName.size())), + }, + attachResult)); + EXPECT_GE(attachResult[0].get(), 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); + EXPECT_NE(mapFdFunc, nullptr); + EXPECT_TRUE(mapFdFunc->isHostFunction()); + auto &mapFdFuncHost = + dynamic_cast(mapFdFunc->getHostFunc()); + + // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result + std::array mapFdResult; + EXPECT_TRUE(mapFdFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(mapNameMemoryOffset)}, + mapFdResult)); + auto mapFd = mapFdResult[0].get(); + EXPECT_GE(mapFd, 0); + + // In the following steps we prepare for polling. + // Create an instance of the polling callback function. + moduleInst.addHostFunc("__polling_callback_hostfunc"sv, + std::make_unique()); + auto *callbackFuncInst = + moduleInst.findFuncExports("__polling_callback_hostfunc"); + // Create a function table and fill it with the callback function. + auto funcTableInst = + std::make_unique( + WasmEdge::AST::TableType(WasmEdge::TypeCode::FuncRef, 1)); + EXPECT_TRUE(funcTableInst->setRefs( + std::initializer_list{callbackFuncInst}, 0, 0, + 1)); + // Add the table to the main module + moduleInst.addHostTable("__indirect_function_table"sv, + std::move(funcTableInst)); + + // Get the "wasm_bpf_buffer_poll" function + auto *bufferPollFunc = module->findFuncExports("wasm_bpf_buffer_poll"); + EXPECT_NE(bufferPollFunc, nullptr); + EXPECT_TRUE(bufferPollFunc->isHostFunction()); + auto &bufferPollFuncHost = dynamic_cast( + bufferPollFunc->getHostFunc()); + + // Call the polling function + std::array pollResult; + for (size_t i = 1; i <= 50; i++) { + using namespace std; + EXPECT_TRUE(bufferPollFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), // object handle + WasmEdge::ValVariant(mapFd), // map fd + UINT32_C(0), // callback function index + UINT32_C(0), // Custom context pointer + WasmEdge::ValVariant(bufferPollMemoryOffset), // buffer offset + WasmEdge::ValVariant(bufferPollSize), // buffer size + UINT32_C(100) // timeout (ms) + }, + pollResult)); + EXPECT_GE(pollResult[0].get(), 0); + } + + // Get function `wasm_close_bpf_object` + auto *closeFunc = module->findFuncExports("wasm_close_bpf_object"); + EXPECT_NE(closeFunc, nullptr); + EXPECT_TRUE(closeFunc->isHostFunction()); + auto &closeFuncHost = + dynamic_cast(closeFunc->getHostFunc()); + + // Call "wasm_close_bpf_object" to attach, and check the result + std::array closeResult; + EXPECT_TRUE(closeFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + }, + closeResult)); + EXPECT_EQ(closeResult[0].get(), 0); +} + +static const size_t MAX_SLOTS = 26; + +struct hist { + unsigned int slots[MAX_SLOTS]; + char comm[TASK_COMM_LEN]; +} __attribute__((packed)); + +TEST(WasmBpfTest, RunBpfProgramWithMapOperation) { + // Test loading and attaching a BPF program and polling a buffer. + auto module = createModule(); + ASSERT_TRUE(module); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); + // moduleInst.addHostFunc() + moduleInst.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *memoryInst = moduleInst.findMemoryExports("memory"); + EXPECT_NE(memoryInst, nullptr); + auto &memoryInstRef = *memoryInst; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &moduleInst); + + namespace fs = std::filesystem; + auto bpfObject = getAssertsPath() / "runqlat.bpf.o"; + + // Ensure the BPF object we need exists. + EXPECT_TRUE(fs::exists(bpfObject)); + + // Read the BPF object into Wasm memory. + std::ifstream bpfObjStream(bpfObject); + EXPECT_TRUE(bpfObjStream.is_open()); + EXPECT_TRUE(bpfObjStream.good()); + std::vector bpfObjectBytes( + (std::istreambuf_iterator(bpfObjStream)), + std::istreambuf_iterator()); + EXPECT_FALSE(bpfObjectBytes.empty()); + // Offset used to place data in memory. + uint32_t nextOffset = 1; + + // Put the BPF object in memory. + const uint32_t bpfObjectMemoryOffset = nextOffset; + fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); + nextOffset += static_cast(bpfObjectBytes.size()); + + // Write the strings to memory. + std::array strings = { + "hists", // Map name + "sched_wakeup", "sched_wakeup_new", "sched_switch", // Program names + "" // An empty string + }; + std::array stringOffsets; + + for (size_t i = 0; i < strings.size(); i++) { + std::string currString(strings[i]); + std::vector bytes(currString.begin(), currString.end()); + // Ensure that strings are zero-terminated + bytes.push_back('\0'); + fillMemContent(memoryInstRef, nextOffset, bytes); + stringOffsets[i] = nextOffset; + nextOffset += static_cast(bytes.size()); + } + + // Get function "wasm_load_bpf_object" + auto *loadFunc = module->findFuncExports("wasm_load_bpf_object"); + EXPECT_NE(loadFunc, nullptr); + EXPECT_TRUE(loadFunc->isHostFunction()); + auto &loadFuncHost = + dynamic_cast(loadFunc->getHostFunc()); + + // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the + // result + std::array loadResult; + EXPECT_TRUE(loadFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(bpfObjectMemoryOffset), + WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, + loadResult)); + auto handle = loadResult[0].get(); + EXPECT_NE(handle, 0); + + // Get function `wasm_attach_bpf_program` + auto *attachFunc = module->findFuncExports("wasm_attach_bpf_program"); + EXPECT_NE(attachFunc, nullptr); + EXPECT_TRUE(attachFunc->isHostFunction()); + auto &attachFuncHost = dynamic_cast( + attachFunc->getHostFunc()); + std::array programNameIndexes = {1, 2, 3}; + + // Attach the programs + for (size_t index : programNameIndexes) { + std::array attachResult; + EXPECT_TRUE( + attachFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(stringOffsets[index]), + // There should be '\0' + WasmEdge::ValVariant(stringOffsets[4]), + }, + attachResult)); + EXPECT_GE(attachResult[0].get(), 0); + } + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); + EXPECT_NE(mapFdFunc, nullptr); + EXPECT_TRUE(mapFdFunc->isHostFunction()); + auto &mapFdFuncHost = + dynamic_cast(mapFdFunc->getHostFunc()); + + // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result + std::array mapFdResult; + EXPECT_TRUE(mapFdFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), WasmEdge::ValVariant(stringOffsets[0])}, + mapFdResult)); + auto histsFd = mapFdResult[0].get(); + EXPECT_GE(histsFd, 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapOptFunc = module->findFuncExports("wasm_bpf_map_operate"); + EXPECT_NE(mapOptFunc, nullptr); + EXPECT_TRUE(mapOptFunc->isHostFunction()); + auto &mapOptFuncHost = + dynamic_cast(mapOptFunc->getHostFunc()); + // A wrapper to call wasm_bpf_map_operate + auto callMapOperate = [&](int32_t fd, int32_t cmd, uint32_t key, + uint32_t value, uint32_t nextKey, + uint64_t flags) -> int32_t { + std::array callResult; + EXPECT_TRUE(mapOptFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(fd), WasmEdge::ValVariant(cmd), + WasmEdge::ValVariant(key), WasmEdge::ValVariant(value), + WasmEdge::ValVariant(nextKey), WasmEdge::ValVariant(flags)}, + callResult)); + return callResult[0].get(); + }; + // Three helper functions used below. + auto mapGetNextKey = [&](int32_t fd, uint32_t lookupKey, + uint32_t nextKey) -> int32_t { + // lookupKey is the last element -> returns -1 + // lookupKey found -> returns 0, set nextKey + // lookupKey not found -> returns 0, set nextKey to the first key + + return callMapOperate(fd, + 4, // BPF_MAP_GET_NEXT_KEY + lookupKey, 0, nextKey, 0); + }; + auto mapLookupElem = [&](int32_t fd, uint32_t key, + uint32_t valueOut) -> int32_t { + // key found -> returns 0 + // key not found -> returns -1 + return callMapOperate(fd, + 1, // BPF_MAP_LOOKUP_ELEM + key, valueOut, 0, 0); + }; + auto mapDeleteElem = [&](int32_t fd, uint32_t key) -> int32_t { + // key found -> return 0 + // key not found -> returns -1 + return callMapOperate(fd, + 3, // BPF_MAP_DELETE_ELEM + key, 0, 0, 0); + }; + // Three helper functions that make reading and writing more convenient. + auto readU32 = [&](uint32_t offset) -> uint32_t { + const auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + return *ptr; + }; + auto writeU32 = [&](uint32_t offset, uint32_t val) { + auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + *ptr = val; + }; + auto readHistRef = [&](uint32_t offset) -> const hist & { + const auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + return *ptr; + }; + const uint32_t lookUpKeyOffset = nextOffset; + nextOffset += sizeof(uint32_t); + const uint32_t nextKeyOffset = nextOffset; + nextOffset += sizeof(uint32_t); + const uint32_t histOffset = nextOffset; + nextOffset += sizeof(hist); + + // Poll 10 times, with interval 1s + for (size_t i = 1; i <= 10; i++) { + using namespace std; + std::this_thread::sleep_for(std::chrono::seconds(1)); + writeU32(lookUpKeyOffset, static_cast(-2)); + while (mapGetNextKey(histsFd, lookUpKeyOffset, nextKeyOffset) == 0) { + EXPECT_GE(mapLookupElem(histsFd, nextKeyOffset, histOffset), 0); + const auto &histRef = readHistRef(histOffset); + size_t maxIdx = 0; + for (size_t i = 0; i < std::size(histRef.slots); i++) + if (histRef.slots[i] > 0) + maxIdx = i; + for (size_t i = 0; i < maxIdx; i++) { + auto low = UINT64_C(1) << (i); + auto high = (UINT64_C(1) << (i + 1)) - 1; + fmt::print("{:<6}...{:<6} {:<6}\n"sv, low, high, histRef.slots[i]); + } + writeU32(lookUpKeyOffset, readU32(nextKeyOffset)); + } + writeU32(lookUpKeyOffset, static_cast(-2)); + while (mapGetNextKey(histsFd, lookUpKeyOffset, nextKeyOffset) == 0) { + EXPECT_GE(mapDeleteElem(histsFd, nextKeyOffset), 0); + writeU32(lookUpKeyOffset, readU32(nextKeyOffset)); + } + fmt::print("\n"sv); + } + + // Get function `wasm_close_bpf_object` + auto *closeFunc = module->findFuncExports("wasm_close_bpf_object"); + EXPECT_NE(closeFunc, nullptr); + EXPECT_TRUE(closeFunc->isHostFunction()); + auto &closeFuncHost = + dynamic_cast(closeFunc->getHostFunc()); + + // Call "wasm_close_bpf_object" to attach, and check the result + std::array closeResult; + EXPECT_TRUE(closeFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + }, + closeResult)); + EXPECT_EQ(closeResult[0].get(), 0); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_ffmpeg/CMakeLists.txt b/test/plugins/wasmedge_ffmpeg/CMakeLists.txt new file mode 100644 index 00000000..1f580014 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasmedgeFFmpegTests + main.cpp + + avcodec/avcodec_func.cpp + avcodec/avCodec.cpp + avcodec/avCodecParameters.cpp + avcodec/avPacket.cpp + avcodec/avCodecCtx.cpp + + avfilter/avfilter_func.cpp + avfilter/avfilter.cpp + + avformat/avformat_func.cpp + avformat/avformatContext.cpp + avformat/avInputOutputContext.cpp + avformat/avStream.cpp + avformat/avChapter.cpp + + avutil/avRational.cpp + avutil/avDictionary.cpp + avutil/avFrame.cpp + avutil/avutil_func.cpp + avutil/avError.cpp + avutil/avSampleFmt.cpp + avutil/avPixfmt.cpp + + swresample/swresample_func.cpp + + swscale/swscale_func.cpp + + utils.cpp +) + +# Downloading a sample file +execute_process( + COMMAND bash ${CMAKE_SOURCE_DIR}/utils/ffmpeg/download-ffmpeg-sample-video.sh ${CMAKE_CURRENT_BINARY_DIR}/ffmpeg-assets + RESULT_VARIABLE DOWNLOAD_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +add_dependencies(wasmedgeFFmpegTests + wasmedgePluginWasmEdgeFFmpeg +) + +target_include_directories(wasmedgeFFmpegTests + PUBLIC + $ + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(wasmedgeFFmpegTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) + +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeFFmpegTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeFFmpegTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeFFmpegTests wasmedgeFFmpegTests) diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp new file mode 100644 index 00000000..17dc5a49 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avcodec/avCodec.h" +#include "avcodec/module.h" +#include "utils.h" + +#include + +// Testing all AVCodecstruct + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVCodec) { + ASSERT_TRUE(AVCodecMod != nullptr); + + uint32_t AVCodecPtr = UINT32_C(20); + uint32_t StringPtr = UINT32_C(68); + uint32_t NumeratorPtr = UINT32_C(72); + uint32_t DenominatorPtr = UINT32_C(76); + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + spdlog::info("Init FFmpeg Structs"sv); + initFFmpegStructs(AVCodecPtr, UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), UINT32_C(64), UINT32_C(68), UINT32_C(72)); + + uint32_t AVCodecId = readUInt32(MemInst, AVCodecPtr); + auto *FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecID = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecId"sv); + { + EXPECT_TRUE(HostFuncAVCodecID.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 27); // H264 + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecType = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecType"sv); + { + EXPECT_TRUE(HostFuncAVCodecType.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), + 0); // MediaType is Video + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_max_lowres"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecMaxLowres = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecMaxLowres"sv); + { + EXPECT_TRUE(HostFuncAVCodecMaxLowres.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_capabilities"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCapabilities = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCapabilities &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecCapabilities"sv); + { + EXPECT_TRUE(HostFuncAVCodecCapabilities.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_TRUE(Result[0].get() > 0); + } + + int32_t Length = 0; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_get_name_len"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecGetNameLen = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecGetNameLen &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecGetNameLen"sv); + { + EXPECT_TRUE(HostFuncAVCodecGetNameLen.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_get_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecGetName = + dynamic_cast( + FuncInst->getHostFunc()); + + // Fill the Memory with 0. + fillMemContent(MemInst, StringPtr, Length); + spdlog::info("Testing AVCodecGetName"sv); + { + EXPECT_TRUE( + HostFuncAVCodecGetName.run(CallFrame, + std::initializer_list{ + AVCodecId, StringPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_get_long_name_len"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecGetLongNameLen = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecGetLongNameLen &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecGetLongNameLen"sv); + { + EXPECT_TRUE(HostFuncAVCodecGetLongNameLen.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_get_long_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecGetLongName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecGetLongName &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecGetLongName"sv); + { + EXPECT_TRUE(HostFuncAVCodecGetLongName.run( + CallFrame, + std::initializer_list{AVCodecId, StringPtr, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_profiles"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecProfiles = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecProfiles"sv); + { + EXPECT_TRUE(HostFuncAVCodecProfiles.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_pix_fmts_is_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecPixFmtIsNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecPixFmtsIsNull &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecPixFmtsIsNull"sv); + { + EXPECT_TRUE(HostFuncAVCodecPixFmtIsNull.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_pix_fmts_iter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecPixFmtIter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecPixFmtsIter &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecPixFmtsIter"sv); + { + uint32_t Idx = 0; + EXPECT_TRUE(HostFuncAVCodecPixFmtIter.run( + CallFrame, std::initializer_list{AVCodecId, Idx}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_supported_framerate_is_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSupportedFrameratesIsNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSupportedFrameratesIsNull + &>(FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSupportedFramratesIsNull"sv); + { + EXPECT_TRUE(HostFuncAVCodecSupportedFrameratesIsNull.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_supported_framerate_iter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSupportedFrameratesIter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSupportedFrameratesIter + &>(FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSupportedFrameratesIter"sv); + { + EXPECT_TRUE(HostFuncAVCodecSupportedFrameratesIter.run( + CallFrame, + std::initializer_list{AVCodecId, 1, NumeratorPtr, + DenominatorPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_supported_samplerates_is_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSupportedSampleRatesIsNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSupportedSampleRatesIsNull + &>(FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSupportedSampleRatesIsNull"sv); + { + EXPECT_TRUE(HostFuncAVCodecSupportedSampleRatesIsNull.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_supported_samplerates_iter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSupportedSampleRatesIter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSupportedSampleRatesIter + &>(FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSupportedSampleRatesIter"sv); + { + EXPECT_TRUE(HostFuncAVCodecSupportedSampleRatesIter.run( + CallFrame, std::initializer_list{AVCodecId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_channel_layouts_is_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecChannelLayoutIsNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecChannelLayoutIsNull &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecChannelLayoutIsNull"sv); + { + EXPECT_TRUE(HostFuncAVCodecChannelLayoutIsNull.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_channel_layouts_iter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecChannelLayoutIter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecChannelLayoutIter &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecChannelLayoutIter"sv); + { + EXPECT_TRUE(HostFuncAVCodecChannelLayoutIter.run( + CallFrame, std::initializer_list{AVCodecId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_sample_fmts_is_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSampleFmtsIsNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSampleFmtsIsNull &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSampleFmtsIsNull"sv); + { + EXPECT_TRUE(HostFuncAVCodecSampleFmtsIsNull.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_sample_fmts_iter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSampleFmtsIter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSampleFmtsIter &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSampleFmtsIter"sv); + { + EXPECT_TRUE(HostFuncAVCodecSampleFmtsIter.run( + CallFrame, std::initializer_list{AVCodecId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } +} +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp new file mode 100644 index 00000000..3407f4e7 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp @@ -0,0 +1,1643 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avcodec/avCodecContext.h" +#include "avcodec/module.h" + +#include "utils.h" + +#include + +// Testing all AVCodecCtxstruct +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVCodecCtx) { + ASSERT_TRUE(AVCodecMod != nullptr); + + uint32_t AVCodecCtxPtr = UINT32_C(64); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFFmpegStructs(UINT32_C(20), UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), AVCodecCtxPtr, UINT32_C(68), UINT32_C(72)); + uint32_t NumPtr = UINT32_C(76); + uint32_t DenPtr = UINT32_C(80); + uint32_t AVCodecPtr = UINT32_C(84); + + uint32_t AVCodecCtxId = readUInt32(MemInst, AVCodecCtxPtr); + + auto *FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_codec_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxCodecID = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxCodecID &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxCodecID.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 27); // H264 + } + + int32_t CodecType = 0; // MediaType Video + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_codec_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetCodecType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetCodecType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetCodecType.run( + CallFrame, + std::initializer_list{AVCodecCtxId, CodecType}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_codec_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxCodecType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxCodecType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxCodecType.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), CodecType); // MediaType Video + } + + int32_t Num = 5; + int32_t Den = 10; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_time_base"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetTimebase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetTimebase &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetTimebase.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Num, Den}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_time_base"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxTimeBase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxTimeBase &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxTimeBase.run( + CallFrame, + std::initializer_list{AVCodecCtxId, NumPtr, + DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + int32_t Numerator = readSInt32(MemInst, NumPtr); + int32_t Denominator = readSInt32(MemInst, DenPtr); + EXPECT_EQ(Numerator, Num); + EXPECT_EQ(Denominator, Den); + } + + int32_t Dimension = 200; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_width"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetWidth = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetWidth &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetWidth.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Dimension}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_width"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxWidth = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxWidth.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), Dimension); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_height"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetHeight = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetHeight &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetHeight.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Dimension}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_height"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxHeight = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxHeight.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), Dimension); + } + + Num = 10; + Den = 20; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_aspect_ratio"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSampleAspectRatio = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSampleAspectRatio + &>(FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSampleAspectRatio.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Num, Den}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_sample_aspect_ratio"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSampleAspectRatio = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSampleAspectRatio &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSampleAspectRatio.run( + CallFrame, + std::initializer_list{AVCodecCtxId, NumPtr, + DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + int32_t Numerator = readSInt32(MemInst, NumPtr); + int32_t Denominator = readSInt32(MemInst, DenPtr); + EXPECT_EQ(Numerator, Num); + EXPECT_EQ(Denominator, Den); + } + + uint64_t ChannelLayoutId = 1; // FRONT_LEFT; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_channel_layout"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetChannelLayout &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetChannelLayout.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + ChannelLayoutId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_channel_layout"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxChannelLayout &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxChannelLayout.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), ChannelLayoutId); + } + + uint32_t PixFormatId = 1; // YUV420P + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_pix_fmt"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetPixFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetPixFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetPixFormat.run( + CallFrame, + std::initializer_list{AVCodecCtxId, PixFormatId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_pix_fmt"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxPixFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxPixFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxPixFormat.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), PixFormatId); + } + + uint32_t SampleFmtId = 1; // SAMPLE_FMT_U8 + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_format"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSampleFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSampleFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSampleFormat.run( + CallFrame, + std::initializer_list{AVCodecCtxId, SampleFmtId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_sample_format"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSampleFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSampleFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSampleFormat.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), SampleFmtId); + } + + int32_t SampleRate = 500; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSampleRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSampleRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSampleRate.run( + CallFrame, + std::initializer_list{AVCodecCtxId, SampleRate}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_sample_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSampleRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSampleRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSampleRate.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), SampleRate); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_gop_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetGopSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetGopSize &>( + FuncInst->getHostFunc()); + + { + int32_t GopSize = 20; + EXPECT_TRUE(HostFuncAVCodecCtxSetGopSize.run( + CallFrame, + std::initializer_list{AVCodecCtxId, GopSize}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_max_b_frames"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMaxBFrames = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMaxBFrames &>( + FuncInst->getHostFunc()); + + { + int32_t MaxBFrames = 30; + EXPECT_TRUE(HostFuncAVCodecCtxSetMaxBFrames.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MaxBFrames}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_b_quant_factor"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetBQuantFactor = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetBQuantFactor &>( + FuncInst->getHostFunc()); + + { + float BQuantFactor = 12.32; + EXPECT_TRUE(HostFuncAVCodecCtxSetBQuantFactor.run( + CallFrame, + std::initializer_list{AVCodecCtxId, BQuantFactor}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_b_quant_offset"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetBQuantOffset = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetBQuantOffset &>( + FuncInst->getHostFunc()); + + { + float BQuantOffset = 3.53; + EXPECT_TRUE(HostFuncAVCodecCtxSetBQuantOffset.run( + CallFrame, + std::initializer_list{AVCodecCtxId, BQuantOffset}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_i_quant_factor"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetIQuantFactor = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetIQuantFactor &>( + FuncInst->getHostFunc()); + + { + float IQuantFactor = 3.435; + EXPECT_TRUE(HostFuncAVCodecCtxSetIQuantFactor.run( + CallFrame, + std::initializer_list{AVCodecCtxId, IQuantFactor}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_i_quant_offset"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetIQuantOffset = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetIQuantOffset &>( + FuncInst->getHostFunc()); + + { + float IQuantOffset = 6.322; + EXPECT_TRUE(HostFuncAVCodecCtxSetIQuantOffset.run( + CallFrame, + std::initializer_list{AVCodecCtxId, IQuantOffset}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_lumi_masking"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetLumiMasking = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetLumiMasking &>( + FuncInst->getHostFunc()); + + { + float LumiMasking = 54.32432; + EXPECT_TRUE(HostFuncAVCodecCtxSetLumiMasking.run( + CallFrame, + std::initializer_list{AVCodecCtxId, LumiMasking}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_temporal_cplx_masking"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetTemporalCplxMasking = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetTemporalCplxMasking + &>(FuncInst->getHostFunc()); + + { + float TemporialCplxMasking = 642.32; + EXPECT_TRUE(HostFuncAVCodecCtxSetTemporalCplxMasking.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + TemporialCplxMasking}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_spatial_cplx_masking"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSpatialCplxMasking = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSpatialCplxMasking + &>(FuncInst->getHostFunc()); + + { + float SpatialCplxMasking = 324.32; + EXPECT_TRUE(HostFuncAVCodecCtxSetSpatialCplxMasking.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + SpatialCplxMasking}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_p_masking"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetPMasking = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetPMasking &>( + FuncInst->getHostFunc()); + + { + float PMasking = 65.3245; + EXPECT_TRUE(HostFuncAVCodecCtxSetPMasking.run( + CallFrame, + std::initializer_list{AVCodecCtxId, PMasking}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_dark_masking"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetDarkMasking = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetDarkMasking &>( + FuncInst->getHostFunc()); + + { + float DarkMasking = 83.32; + EXPECT_TRUE(HostFuncAVCodecCtxSetDarkMasking.run( + CallFrame, + std::initializer_list{AVCodecCtxId, DarkMasking}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_cmp"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMeCmp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMeCmp &>( + FuncInst->getHostFunc()); + + { + int32_t MeCmp = 532; + EXPECT_TRUE(HostFuncAVCodecCtxSetMeCmp.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MeCmp}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_sub_cmp"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMeSubCmp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMeSubCmp &>( + FuncInst->getHostFunc()); + + { + int32_t MeSubCmp = 321; + EXPECT_TRUE(HostFuncAVCodecCtxSetMeSubCmp.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MeSubCmp}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_cmp"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMbCmp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMbCmp &>( + FuncInst->getHostFunc()); + + { + int32_t MbCmp = 243; + EXPECT_TRUE(HostFuncAVCodecCtxSetMbCmp.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MbCmp}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_ildct_cmp"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetIldctCmp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetIldctCmp &>( + FuncInst->getHostFunc()); + + { + int32_t IldctCmp = 3; + EXPECT_TRUE(HostFuncAVCodecCtxSetIldctCmp.run( + CallFrame, + std::initializer_list{AVCodecCtxId, IldctCmp}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_dia_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetDiaSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetDiaSize &>( + FuncInst->getHostFunc()); + + { + int32_t DiaSize = 9; + EXPECT_TRUE(HostFuncAVCodecCtxSetDiaSize.run( + CallFrame, + std::initializer_list{AVCodecCtxId, DiaSize}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_last_predictor_count"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetLastPredictorsCount = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetLastPredictorsCount + &>(FuncInst->getHostFunc()); + + { + int32_t LastPredictorCount = 21; + EXPECT_TRUE(HostFuncAVCodecCtxSetLastPredictorsCount.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + LastPredictorCount}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_pre_cmp"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMePreCmp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMePreCmp &>( + FuncInst->getHostFunc()); + + { + int32_t MePreCmp = 53; + EXPECT_TRUE(HostFuncAVCodecCtxSetMePreCmp.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MePreCmp}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_pre_dia_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetPreDiaSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetPreDiaSize &>( + FuncInst->getHostFunc()); + + { + int32_t PreDiaSize = 74; + EXPECT_TRUE(HostFuncAVCodecCtxSetPreDiaSize.run( + CallFrame, + std::initializer_list{AVCodecCtxId, PreDiaSize}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_subpel_quality"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMeSubpelQuality = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMeSubpelQuality &>( + FuncInst->getHostFunc()); + + { + int32_t MeSubpelQuality = 85; + EXPECT_TRUE(HostFuncAVCodecCtxSetMeSubpelQuality.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + MeSubpelQuality}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_range"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMeRange = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMeRange &>( + FuncInst->getHostFunc()); + + { + int32_t SetMeRange = 31; + EXPECT_TRUE(HostFuncAVCodecCtxSetMeRange.run( + CallFrame, + std::initializer_list{AVCodecCtxId, SetMeRange}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_decision"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMbDecision = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMbDecision &>( + FuncInst->getHostFunc()); + + { + int32_t MbDecision = 78; + EXPECT_TRUE(HostFuncAVCodecCtxSetMbDecision.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MbDecision}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_lmin"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMbLMin = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMbLMin &>( + FuncInst->getHostFunc()); + + { + int32_t MbLMin = 11; + EXPECT_TRUE(HostFuncAVCodecCtxSetMbLMin.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MbLMin}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_lmax"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMbLMax = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMbLMax &>( + FuncInst->getHostFunc()); + + { + int32_t MbLMax = 18; + EXPECT_TRUE(HostFuncAVCodecCtxSetMbLMax.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MbLMax}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_intra_dc_precision"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + int32_t IntraDcPrecision = 323; + auto &HostFuncAVCodecCtxSetIntraDcPrecision = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetIntraDcPrecision &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetIntraDcPrecision.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + IntraDcPrecision}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_intra_dc_precision"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxIntraDcPrecision = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxIntraDcPrecision &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxIntraDcPrecision.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), IntraDcPrecision); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_qmin"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetQMin = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetQMin &>( + FuncInst->getHostFunc()); + + { + int32_t QMin = 10; + EXPECT_TRUE(HostFuncAVCodecCtxSetQMin.run( + CallFrame, + std::initializer_list{AVCodecCtxId, QMin}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_qmax"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetQMax = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetQMax &>( + FuncInst->getHostFunc()); + + { + int32_t QMax = 20; + EXPECT_TRUE(HostFuncAVCodecCtxSetQMax.run( + CallFrame, + std::initializer_list{AVCodecCtxId, QMax}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_global_quality"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetGlobalQuality = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetGlobalQuality &>( + FuncInst->getHostFunc()); + + { + int32_t GlobalQuality = 93; + EXPECT_TRUE(HostFuncAVCodecCtxSetGlobalQuality.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + GlobalQuality}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_colorspace"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetColorspace = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetColorspace &>( + FuncInst->getHostFunc()); + + int32_t ColorspaceId = 1; // AVCOL_SPC_BT709 + { + EXPECT_TRUE(HostFuncAVCodecCtxSetColorspace.run( + CallFrame, + std::initializer_list{AVCodecCtxId, ColorspaceId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_colorspace"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxColorspace = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxColorspace &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxColorspace.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), ColorspaceId); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_color_range"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetColorRange = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetColorRange &>( + FuncInst->getHostFunc()); + + int32_t ColorRangeId = 1; // MPEG + { + EXPECT_TRUE(HostFuncAVCodecCtxSetColorRange.run( + CallFrame, + std::initializer_list{AVCodecCtxId, ColorRangeId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_color_range"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxColorRange = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxColorRange &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxColorRange.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), ColorRangeId); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_frame_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxFrameSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxFrameSize &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxFrameSize.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_bit_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetBitRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetBitRate &>( + FuncInst->getHostFunc()); + + int64_t BitRate = 9932; + { + EXPECT_TRUE(HostFuncAVCodecCtxSetBitRate.run( + CallFrame, + std::initializer_list{AVCodecCtxId, BitRate}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_bit_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxBitRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxBitRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxBitRate.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), BitRate); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_rc_max_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + int64_t RcMaxRate = 3245; + auto &HostFuncAVCodecCtxSetRcMaxRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetRcMaxRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetRcMaxRate.run( + CallFrame, + std::initializer_list{AVCodecCtxId, RcMaxRate}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_rc_max_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxRcMaxRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxRcMaxRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxRcMaxRate.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), RcMaxRate); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_bit_rate_tolerance"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetBitRateTolerance = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetBitRateTolerance &>( + FuncInst->getHostFunc()); + + { + int32_t BitRateTolerance = 9543; + EXPECT_TRUE(HostFuncAVCodecCtxSetBitRateTolerance.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + BitRateTolerance}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_compression_level"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetCompressionLevel = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetCompressionLevel &>( + FuncInst->getHostFunc()); + + { + int32_t CompressionLevel = 934; + EXPECT_TRUE(HostFuncAVCodecCtxSetCompressionLevel.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + CompressionLevel}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_framerate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + Num = 20; + Den = 30; + auto &HostFuncAVCodecCtxSetFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetFrameRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetFrameRate.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Num, Den}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_framerate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxFrameRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxFrameRate.run( + CallFrame, + std::initializer_list{AVCodecCtxId, NumPtr, + DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + int32_t Numerator = readSInt32(MemInst, NumPtr); + int32_t Denominator = readSInt32(MemInst, DenPtr); + EXPECT_EQ(Numerator, Num); + EXPECT_EQ(Denominator, Den); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_flags"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetFlags = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetFlags &>( + FuncInst->getHostFunc()); + + { + int32_t Flags = 3; + EXPECT_TRUE(HostFuncAVCodecCtxSetFlags.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Flags}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_strict_std_compliance"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetStrictStdCompliance = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetStrictStdCompliance + &>(FuncInst->getHostFunc()); + + { + int32_t ComplianceId = 3; + EXPECT_TRUE(HostFuncAVCodecCtxSetStrictStdCompliance.run( + CallFrame, + std::initializer_list{AVCodecCtxId, ComplianceId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_debug"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetDebug = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetDebug &>( + FuncInst->getHostFunc()); + + { + int32_t Debug = 50; + EXPECT_TRUE(HostFuncAVCodecCtxSetDebug.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Debug}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_codec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxCodec = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxCodec.run( + CallFrame, + std::initializer_list{AVCodecCtxId, AVCodecPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, AVCodecPtr) > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_channels"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetChannels = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetChannels &>( + FuncInst->getHostFunc()); + + int32_t Channels = 10; + { + EXPECT_TRUE(HostFuncAVCodecCtxSetChannels.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Channels}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_channels"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxChannels = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxChannels &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxChannels.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), Channels); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_loop_filter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSkipLoopFilter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSkipLoopFilter &>( + FuncInst->getHostFunc()); + + int32_t DiscardId = 16; // Bidirectional + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSkipLoopFilter.run( + CallFrame, + std::initializer_list{AVCodecCtxId, DiscardId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSkipFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSkipFrame &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSkipFrame.run( + CallFrame, + std::initializer_list{AVCodecCtxId, DiscardId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_idct"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSkipIdct = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSkipIdct &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSkipIdct.run( + CallFrame, + std::initializer_list{AVCodecCtxId, DiscardId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_error_concealment"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetErrorConcealment = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetErrorConcealment &>( + FuncInst->getHostFunc()); + + { + int32_t ErrorConcealment = 99; + EXPECT_TRUE(HostFuncAVCodecCtxSetErrorConcealment.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + ErrorConcealment}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_err_recognition"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetErrorRecognition = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetErrorRecognition &>( + FuncInst->getHostFunc()); + + { + int32_t ErrorRecognition = 88; + EXPECT_TRUE(HostFuncAVCodecCtxSetErrorRecognition.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + ErrorRecognition}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_delay"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxDelay = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxDelay.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_top"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSkipTop = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSkipTop &>( + FuncInst->getHostFunc()); + + { + int32_t Value = 50; + EXPECT_TRUE(HostFuncAVCodecCtxSetSkipTop.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Value}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_bottom"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSkipBottom = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSkipBottom &>( + FuncInst->getHostFunc()); + + { + int32_t Value = 60; + EXPECT_TRUE(HostFuncAVCodecCtxSetSkipBottom.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Value}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_refs"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxRefs = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxRefs.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 4); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_slice_flags"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSliceFlags = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSliceFlags &>( + FuncInst->getHostFunc()); + + { + int32_t Value = 70; + EXPECT_TRUE(HostFuncAVCodecCtxSetSliceFlags.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Value}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_slice_count"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSliceCount = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSliceCount &>( + FuncInst->getHostFunc()); + + { + int32_t Value = 100; + EXPECT_TRUE(HostFuncAVCodecCtxSetSliceCount.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Value}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_field_order"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetFieldOrder = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetFieldOrder &>( + FuncInst->getHostFunc()); + + { + int32_t Value = 200; + EXPECT_TRUE(HostFuncAVCodecCtxSetFieldOrder.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Value}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_color_trc"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxColorTrc = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxColorTrc &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxColorTrc.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_chroma_sample_location"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxChromaSampleLocation = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxChromaSampleLocation + &>(FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxChromaSampleLocation.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_frame_number"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxFrameNumber = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxFrameNumber &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxFrameNumber.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_block_align"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxBlockAlign = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxBlockAlign &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxBlockAlign.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_request_sample_fmt"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetRequestSampleFmt = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetRequestSampleFmt &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetRequestSampleFmt.run( + CallFrame, + std::initializer_list{AVCodecCtxId, SampleFmtId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_audio_service_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxAudioServiceType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxAudioServiceType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxAudioServiceType.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_has_b_frames"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxHasBFrames = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxHasBFrames &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxHasBFrames.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_active_thread_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxActiveThreadType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxActiveThreadType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxActiveThreadType.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_thread_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetThreadType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetThreadType &>( + FuncInst->getHostFunc()); + + { + int32_t ThreadType = 1; // Frame + EXPECT_TRUE(HostFuncAVCodecCtxSetThreadType.run( + CallFrame, + std::initializer_list{AVCodecCtxId, ThreadType}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_thread_count"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetThreadCount = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetThreadCount &>( + FuncInst->getHostFunc()); + + int32_t ThreadCount = 50; + { + EXPECT_TRUE(HostFuncAVCodecCtxSetThreadCount.run( + CallFrame, + std::initializer_list{AVCodecCtxId, ThreadCount}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_thread_count"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxThreadCount = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxThreadCount &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxThreadCount.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), ThreadCount); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_color_primaries"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxColorPrimaries = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxColorPrimaries &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxColorPrimaries.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp new file mode 100644 index 00000000..2bcfdfd4 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avcodec/avCodecParameters.h" +#include "avcodec/module.h" + +#include "utils.h" + +#include + +// Testing all AVCodecstruct + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVCodecParameters) { + ASSERT_TRUE(AVCodecMod != nullptr); + + uint32_t AVCodecParamPtr = UINT32_C(60); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFFmpegStructs(UINT32_C(20), UINT32_C(24), UINT32_C(28), FileName, + AVCodecParamPtr, UINT32_C(64), UINT32_C(68), UINT32_C(72)); + + uint32_t AVCodecParamId = readUInt32(MemInst, AVCodecParamPtr); + + auto *FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodecparam_codec_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParamCodecId = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParamCodecId &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecParamCodecId.run( + CallFrame, std::initializer_list{AVCodecParamId}, + Result)); + EXPECT_EQ(Result[0].get(), 27); // H264 + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodecparam_codec_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParamCodecType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParamCodecType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecParamCodecType.run( + CallFrame, std::initializer_list{AVCodecParamId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); // MediaType Video + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodecparam_set_codec_tag"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParamSetCodecTag = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParamSetCodecTag &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecParamSetCodecTag.run( + CallFrame, + std::initializer_list{AVCodecParamId, 20}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp new file mode 100644 index 00000000..fa163e49 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp @@ -0,0 +1,371 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avcodec/avPacket.h" +#include "avcodec/module.h" + +#include "utils.h" + +#include + +// Testing all AVPacket + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVPacketTest) { + ASSERT_TRUE(AVCodecMod != nullptr); + + uint32_t PacketPtr = UINT32_C(4); + uint32_t PacketPtr2 = UINT32_C(8); + uint32_t DataPtr = UINT32_C(12); + + auto *FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_alloc"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVPacketAlloc = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketAlloc.run( + CallFrame, std::initializer_list{PacketPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVPacketAlloc.run( + CallFrame, std::initializer_list{PacketPtr2}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t PacketId = readUInt32(MemInst, PacketPtr); + uint32_t PacketId2 = readUInt32(MemInst, PacketPtr2); + ASSERT_TRUE(PacketId > 0); + ASSERT_TRUE(PacketId2 > 0); + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_new_packet"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVNewPacket = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t Size = 40; + EXPECT_TRUE(HostFuncAVNewPacket.run( + CallFrame, std::initializer_list{PacketId, Size}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_grow_packet"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVGrowPacket = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t Size = 40; + EXPECT_TRUE(HostFuncAVGrowPacket.run( + CallFrame, std::initializer_list{PacketId, Size}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_shrink_packet"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVShrinkPacket = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t Size = 40; + EXPECT_TRUE(HostFuncAVShrinkPacket.run( + CallFrame, std::initializer_list{PacketId, Size}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + uint32_t StreamIdx = 3; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_set_stream_index"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetStreamIndex = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketSetStreamIndex &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetStreamIndex.run( + CallFrame, + std::initializer_list{PacketId, StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_stream_index"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketStreamIndex = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketStreamIndex &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketStreamIndex.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), StreamIdx); + } + + uint32_t Size = 0; + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSize = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSize.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + Size = Result[0].get(); + EXPECT_TRUE(Size > 0); + } + + uint32_t Flags = 5; + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_set_flags"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetFlags.run( + CallFrame, std::initializer_list{PacketId, Flags}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_flags"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketFlags.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), Flags); + } + + int64_t Pos = 500; + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_set_pos"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetPos = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetPos.run( + CallFrame, std::initializer_list{PacketId, Pos}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_pos"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketPos = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketPos.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), Pos); + } + + int64_t Duration = 100; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_set_duration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetDuration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketSetDuration &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetDuration.run( + CallFrame, + std::initializer_list{PacketId, Duration}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_duration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketDuration = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketDuration.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), Duration); + } + + int64_t Dts = 1000; + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_set_dts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetDts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetDts.run( + CallFrame, std::initializer_list{PacketId, Dts}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_dts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketDts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketDts.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), Dts); + } + + int64_t Pts = 5000; + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_set_pts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetPts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetPts.run( + CallFrame, std::initializer_list{PacketId, Pts}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_pts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketPts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketPts.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), Pts); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_is_data_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketIsDataNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketIsDataNull &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketIsDataNull.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_data"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketData = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketData.run( + CallFrame, + std::initializer_list{PacketId, DataPtr, Size}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_ref"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVPacketRef = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketRef.run( + CallFrame, + std::initializer_list{PacketId2, PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_unref"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVPacketUnref = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketUnref.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp new file mode 100644 index 00000000..f82fb2d6 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp @@ -0,0 +1,577 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avcodec/avcodec_func.h" +#include "avcodec/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// TODO: Commented functions need tests. + +TEST_F(FFmpegTest, AVCodecFunc) { + ASSERT_TRUE(AVCodecMod != nullptr); + + uint32_t CodecCtxPtr = UINT32_C(4); + uint32_t CodecParamPtr = UINT32_C(8); + uint32_t CodecParamPtr2 = UINT32_C(20); + uint32_t CodecDecoderPtr = UINT32_C(12); + uint32_t CodecEncoderPtr = UINT32_C(16); + uint32_t StrPtr = UINT32_C(32); + + uint32_t CodecNamePtr = UINT32_C(150); + std::string CodecName = "mpeg1video"; + spdlog::info("Filling memory CodecName into CodecNamePtr"sv); + fillMemContent(MemInst, CodecNamePtr, CodecName); + + uint32_t ID = 1; + + auto *FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_alloc_context3"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecAllocContext3 = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecAllocContext3 &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AvCodecAllocContext3"sv); + { + EXPECT_TRUE(HostFuncAVCodecAllocContext3.run( + CallFrame, std::initializer_list{0, CodecCtxPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t AVCodecCtxId = readUInt32(MemInst, CodecCtxPtr); + ASSERT_TRUE(AVCodecCtxId > 0); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_alloc"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParametersAlloc = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersAlloc &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecParametersAlloc"sv); + { + EXPECT_TRUE(HostFuncAVCodecParametersAlloc.run( + CallFrame, std::initializer_list{CodecParamPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVCodecParametersAlloc.run( + CallFrame, std::initializer_list{CodecParamPtr2}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t AVCodecParamId = readUInt32(MemInst, CodecParamPtr); + ASSERT_TRUE(AVCodecParamId > 0); + + uint32_t AVCodecParamId2 = readUInt32(MemInst, CodecParamPtr2); + ASSERT_TRUE(AVCodecParamId2 > 0); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_from_context"sv); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParametersFromContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersFromContext &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecParametersFromContext"sv); + { + EXPECT_TRUE(HostFuncAVCodecParametersFromContext.run( + CallFrame, + std::initializer_list{AVCodecParamId, + AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_get_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecGetType = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecGetType"sv); + { + EXPECT_TRUE(HostFuncAVCodecGetType.run( + CallFrame, std::initializer_list{ID}, Result)); + EXPECT_EQ(Result[0].get(), 0); // Video Type + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_find_decoder"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFindDecoder = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFindDecoder &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFindDecoder"sv); + { + EXPECT_TRUE(HostFuncAVCodecFindDecoder.run( + CallFrame, + std::initializer_list{ID, CodecDecoderPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t AVCodecDecoderId = readUInt32(MemInst, CodecDecoderPtr); + ASSERT_TRUE(AVCodecDecoderId > 0); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_find_encoder"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFindEncoder = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFindEncoder &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFindEncoder"sv); + { + EXPECT_TRUE(HostFuncAVCodecFindEncoder.run( + CallFrame, + std::initializer_list{ID, CodecEncoderPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t AVCodecEncoderId = readUInt32(MemInst, CodecEncoderPtr); + ASSERT_TRUE(AVCodecEncoderId > 0); + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_open2"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecOpen2 = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecOpen2"sv); + // Invalid argument passed. Return -22 Error code. Means functionality + // working. + { + EXPECT_TRUE( + HostFuncAVCodecOpen2.run(CallFrame, + std::initializer_list{ + AVCodecCtxId, AVCodecEncoderId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_codec_is_encoder"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecIsEncoder = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecIsEncoder"sv); + { + EXPECT_TRUE(HostFuncAVCodecIsEncoder.run( + CallFrame, + std::initializer_list{AVCodecEncoderId}, Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_codec_is_decoder"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecIsDecoder = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecIsDecoder"sv); + { + EXPECT_TRUE(HostFuncAVCodecIsDecoder.run( + CallFrame, + std::initializer_list{AVCodecDecoderId}, Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_find_decoder_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFindDecoderByName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFindDecoderByName &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFindDecoderByName"sv); + { + uint32_t Length = CodecName.length(); + EXPECT_TRUE(HostFuncAVCodecFindDecoderByName.run( + CallFrame, + std::initializer_list{CodecDecoderPtr, + CodecNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_find_encoder_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFindEncoderByName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFindEncoderByName &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFindEncoderByName"sv); + { + uint32_t Length = CodecName.length(); + EXPECT_TRUE(HostFuncAVCodecFindEncoderByName.run( + CallFrame, + std::initializer_list{CodecEncoderPtr, + CodecNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_to_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParametersToContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersToContext &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecParametersToContext"sv); + { + EXPECT_TRUE(HostFuncAVCodecParametersToContext.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + AVCodecParamId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + // TODO: Need FormatCtxId to test this function. + // FuncInst = AVCodecMod->findFuncExports( + // "wasmedge_ffmpeg_avcodec_avcodec_parameters_copy"); + // EXPECT_NE(FuncInst, nullptr); + // EXPECT_TRUE(FuncInst->isHostFunction()); + // + // auto &HostFuncAVCodecParametersCopy = dynamic_cast< + // WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersCopy &>( + // FuncInst->getHostFunc()); + // + // { + // EXPECT_TRUE(HostFuncAVCodecParametersCopy.run( + // CallFrame, std::initializer_list{}, Result)); + // EXPECT_EQ(Result[0].get(), + // static_cast(ErrNo::Success)); + // } + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_version"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecVersion = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecVersion"sv); + { + EXPECT_TRUE(HostFuncAVCodecVersion.run( + CallFrame, std::initializer_list{}, Result)); + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_configuration_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecConfigurationLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecConfigurationLength &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecConfigurationLength"sv); + int32_t Length = 0; + { + EXPECT_TRUE(HostFuncAVCodecConfigurationLength.run( + CallFrame, std::initializer_list{}, Result)); + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_configuration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecConfiguration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecConfiguration &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecConfiguration"sv); + { + EXPECT_TRUE(HostFuncAVCodecConfiguration.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_license_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecLicenseLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecLicenseLength &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecLicenseLength"sv); + { + EXPECT_TRUE(HostFuncAVCodecLicenseLength.run( + CallFrame, std::initializer_list{}, Result)); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_license"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecLicense = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecLicense"sv); + { + EXPECT_TRUE(HostFuncAVCodecLicense.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_free_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFreeContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFreeContext &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFreeContext"sv); + { + EXPECT_TRUE(HostFuncAVCodecFreeContext.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_free"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParametersFree = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersFree &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecParametersFree"sv); + { + EXPECT_TRUE(HostFuncAVCodecParametersFree.run( + CallFrame, std::initializer_list{AVCodecParamId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +TEST_F(FFmpegTest, SendPacketReceiveFrame) { + std::string FileName = "ffmpeg-assets/dummy.mp4"; // 32 chars + uint32_t CodecCtxPtr = UINT32_C(64); + uint32_t FramePtr = UINT32_C(72); + uint32_t PacketPtr = UINT32_C(68); + initFFmpegStructs(UINT32_C(20), UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), CodecCtxPtr, PacketPtr, FramePtr); + + uint32_t FrameId = readUInt32(MemInst, FramePtr); + uint32_t PacketId = readUInt32(MemInst, PacketPtr); + uint32_t CodecCtxId = readUInt32(MemInst, CodecCtxPtr); + + auto *FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_send_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSendFrame = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSendFrame"sv); + // Invalid Argument Error. Should Use Encoder, I'm using decoder + // Aim is to test the functionality. + { + EXPECT_TRUE(HostFuncAVCodecSendFrame.run( + CallFrame, + std::initializer_list{CodecCtxId, FrameId}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + // Invalid Argument Error. Should Use Encoder, I'm using decoder + // Aim is to test the functionality. + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_receive_packet"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecReceivePacket = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecReceivePacket &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecReceivePacket"sv); + { + EXPECT_TRUE(HostFuncAVCodecReceivePacket.run( + CallFrame, + std::initializer_list{CodecCtxId, PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_send_packet"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSendPacket = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSendPacket &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSendPacket"sv); + // Send packet to Decoder. + { + EXPECT_TRUE(HostFuncAVCodecSendPacket.run( + CallFrame, + std::initializer_list{CodecCtxId, PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_receive_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + // Decoder Receives the Packet as Frame. + auto &HostFuncAVCodecReceiveFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecReceiveFrame &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecReceiveFrame"sv); + { + EXPECT_TRUE(HostFuncAVCodecReceiveFrame.run( + CallFrame, + std::initializer_list{CodecCtxId, FrameId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_rescale_ts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVPacketRescaleTs = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketRescaleTs &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVPacketRescaleTs"sv); + { + int32_t SrcNum = 2; + int32_t SrcDen = 3; + int32_t DestNum = 5; + int32_t DestDen = 9; + EXPECT_TRUE(HostFuncAVPacketRescaleTs.run( + CallFrame, + std::initializer_list{PacketId, SrcNum, SrcDen, + DestNum, DestDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_make_writable"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVPacketMakeWritable = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketMakeWritable &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVPacketMakeWritable"sv); + { + EXPECT_TRUE(HostFuncAVPacketMakeWritable.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_flush_buffers"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFlushBuffers = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFlushBuffers &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFlushBuffers"sv); + { + EXPECT_TRUE(HostFuncAVCodecFlushBuffers.run( + CallFrame, std::initializer_list{CodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_close"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecClose = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecClose"sv); + { + EXPECT_TRUE(HostFuncAVCodecClose.run( + CallFrame, std::initializer_list{CodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp new file mode 100644 index 00000000..c96448a2 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avfilter/avFilter.h" +#include "avfilter//avfilter_func.h" +#include "avfilter/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVFilterStructs) { + ASSERT_TRUE(AVFilterMod != nullptr); + + uint32_t FilterPtr = UINT32_C(8); + uint32_t InputFilterPadPtr = UINT32_C(12); + uint32_t OutputFilterPadPtr = UINT32_C(16); + uint32_t InputNamePtr = UINT32_C(100); + uint32_t StrPtr = UINT32_C(150); + + std::string InputName = std::string("abuffer"); + fillMemContent(MemInst, InputNamePtr, InputName); + + // ================================================================== + // Start Initialize AVFilter + // ================================================================== + + auto *FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_get_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGetByName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGetByName &>( + FuncInst->getHostFunc()); + + { + int32_t Length = InputName.length(); + EXPECT_TRUE(HostFuncAVFilterGetByName.run( + CallFrame, + std::initializer_list{FilterPtr, InputNamePtr, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t FilterId = readUInt32(MemInst, FilterPtr); + ASSERT_TRUE(FilterId > 0); + // ================================================================== + // End Initialize AVFilter + // ================================================================== + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_name_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterNameLength &>( + FuncInst->getHostFunc()); + + int32_t Length = 0; + { + EXPECT_TRUE(HostFuncAVFilterNameLength.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = + AVFilterMod->findFuncExports("wasmedge_ffmpeg_avfilter_avfilter_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterName = + dynamic_cast( + FuncInst->getHostFunc()); + + fillMemContent(MemInst, StrPtr, Length); + { + EXPECT_TRUE(HostFuncAVFilterName.run( + CallFrame, + std::initializer_list{FilterId, StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_description_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterDescriptionLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterDescriptionLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterDescriptionLength.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_description"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterDescription = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterDescription &>( + FuncInst->getHostFunc()); + + fillMemContent(MemInst, StrPtr, Length); + { + EXPECT_TRUE(HostFuncAVFilterDescription.run( + CallFrame, + std::initializer_list{FilterId, StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_nb_inputs"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterNbInputs = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterNbInputs &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterNbInputs.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_nb_outputs"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterNbOutputs = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterNbOutputs &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterNbOutputs.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = + AVFilterMod->findFuncExports("wasmedge_ffmpeg_avfilter_avfilter_flags"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterFlags.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_get_inputs_filter_pad"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGetInputsFilterPad = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGetInputsFilterPad &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGetInputsFilterPad.run( + CallFrame, + std::initializer_list{FilterId, + InputFilterPadPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_get_outputs_filter_pad"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGetOutputsFilterPad = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGetOutputsFilterPad &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGetOutputsFilterPad.run( + CallFrame, + std::initializer_list{FilterId, + OutputFilterPadPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t OutputFilterPadId = readUInt32(MemInst, OutputFilterPadPtr); + ASSERT_TRUE(OutputFilterPadId > 0); + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_pad_get_name_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterPadGetNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterPadGetNameLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterPadGetNameLength.run( + CallFrame, + std::initializer_list{OutputFilterPadId, 0}, + Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_pad_get_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterPadGetName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterPadGetName &>( + FuncInst->getHostFunc()); + + { + int32_t Idx = 0; + EXPECT_TRUE(HostFuncAVFilterPadGetName.run( + CallFrame, + std::initializer_list{OutputFilterPadId, Idx, + StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_pad_get_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterPadGetType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterPadGetType &>( + FuncInst->getHostFunc()); + + { + int32_t Idx = 0; + EXPECT_TRUE(HostFuncAVFilterPadGetType.run( + CallFrame, + std::initializer_list{OutputFilterPadId, Idx}, + Result)); + EXPECT_EQ(Result[0].get(), 1); // Audio + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_pad_drop"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterPadDrop = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterPadDrop.run( + CallFrame, + std::initializer_list{OutputFilterPadId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp new file mode 100644 index 00000000..1e50b401 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp @@ -0,0 +1,685 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avfilter//avfilter_func.h" +#include "avfilter/avFilter.h" +#include "avfilter/buffer_source_sink.h" +#include "avfilter/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVFilterFunc) { + ASSERT_TRUE(AVFilterMod != nullptr); + + // Structs Ptr + uint32_t FilterGraphPtr = UINT32_C(4); + uint32_t FilterPtr = UINT32_C(8); + uint32_t Filter2Ptr = UINT32_C(12); + uint32_t InputFilterCtxPtr = UINT32_C(28); // AVFilterContext + uint32_t OutputFilterCtxPtr = UINT32_C(24); // AVFilterContext + uint32_t InputInOutPtr = UINT32_C(32); + uint32_t OutputInOutPtr = UINT32_C(36); + uint32_t FramePtr = UINT32_C(40); + + // Strings. + uint32_t InputNamePtr = UINT32_C(100); + uint32_t OutputNamePtr = UINT32_C(150); + uint32_t InputFilterNamePtr = UINT32_C(200); + uint32_t OutputFilterNamePtr = UINT32_C(250); + uint32_t ArgsPtr = UINT32_C(300); + uint32_t SpecPtr = UINT32_C(450); + uint32_t StrPtr = UINT32_C(500); + + std::string InputName = std::string("abuffer"); + fillMemContent(MemInst, InputNamePtr, InputName); + + std::string OutputName = std::string("abuffersink"); + fillMemContent(MemInst, OutputNamePtr, OutputName); + + std::string InputFilterName = std::string("in"); + fillMemContent(MemInst, InputFilterNamePtr, InputFilterName); + + std::string OutputFilterName = std::string("out"); + fillMemContent(MemInst, OutputFilterNamePtr, OutputFilterName); + + std::string Args = std::string( + "time_base=1/44100:sample_rate=44100:sample_fmt=fltp:channel_layout=0x3"); + fillMemContent(MemInst, ArgsPtr, Args); + + std::string SpecStr = std::string("anull"); + fillMemContent(MemInst, SpecPtr, SpecStr); + + initEmptyFrame(FramePtr); + uint32_t FrameId = readUInt32(MemInst, FramePtr); + + auto *FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_alloc"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphAlloc = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphAlloc &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGraphAlloc.run( + CallFrame, std::initializer_list{FilterGraphPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t FilterGraphId = readUInt32(MemInst, FilterGraphPtr); + ASSERT_TRUE(FilterGraphId > 0); + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_get_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGetByName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGetByName &>( + FuncInst->getHostFunc()); + + { + int32_t Length = InputName.length(); + EXPECT_TRUE(HostFuncAVFilterGetByName.run( + CallFrame, + std::initializer_list{FilterPtr, InputNamePtr, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + Length = OutputName.length(); + EXPECT_TRUE(HostFuncAVFilterGetByName.run( + CallFrame, + std::initializer_list{Filter2Ptr, OutputNamePtr, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t FilterId = readUInt32(MemInst, FilterPtr); + uint32_t Filter2Id = readUInt32(MemInst, Filter2Ptr); + ASSERT_TRUE(FilterId > 0); + ASSERT_TRUE(Filter2Id > 0); + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_create_filter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphCreateFilter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphCreateFilter &>( + FuncInst->getHostFunc()); + + { + int32_t NameLen = InputFilterName.length(); + int32_t ArgsLen = Args.length(); + EXPECT_TRUE(HostFuncAVFilterGraphCreateFilter.run( + CallFrame, + std::initializer_list{ + InputFilterCtxPtr, FilterId, InputFilterNamePtr, NameLen, ArgsPtr, + ArgsLen, FilterGraphId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + writeUInt32(MemInst, 0, InputFilterCtxPtr); // Setting InputFilterCtx to 0 + + NameLen = OutputFilterName.length(); + EXPECT_TRUE(HostFuncAVFilterGraphCreateFilter.run( + CallFrame, + std::initializer_list{ + OutputFilterCtxPtr, Filter2Id, OutputFilterNamePtr, NameLen, 0, 0, + FilterGraphId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + writeUInt32(MemInst, 0, OutputFilterCtxPtr); // Setting OutputFilterCtx to 0 + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_inout_alloc"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterInOutAlloc = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutAlloc &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterInOutAlloc.run( + CallFrame, std::initializer_list{InputInOutPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFilterInOutAlloc.run( + CallFrame, std::initializer_list{OutputInOutPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t InputInOutId = readUInt32(MemInst, InputInOutPtr); + ASSERT_TRUE(InputInOutId > 0); + + uint32_t OutputInOutId = readUInt32(MemInst, OutputInOutPtr); + ASSERT_TRUE(OutputInOutId > 0); + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_get_filter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphGetFilter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphGetFilter &>( + FuncInst->getHostFunc()); + + { + int32_t Length = OutputFilterName.length(); + EXPECT_TRUE(HostFuncAVFilterGraphGetFilter.run( + CallFrame, + std::initializer_list{ + OutputFilterCtxPtr, FilterGraphId, OutputFilterNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + Length = InputFilterName.length(); + EXPECT_TRUE(HostFuncAVFilterGraphGetFilter.run( + CallFrame, + std::initializer_list{ + InputFilterCtxPtr, FilterGraphId, InputFilterNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t OutputFilterCtxId = readUInt32(MemInst, OutputFilterCtxPtr); + ASSERT_TRUE(OutputFilterCtxId > 0); + + uint32_t InputFilterCtxId = readUInt32(MemInst, InputFilterCtxPtr); + ASSERT_TRUE(InputFilterCtxId > 0); + + // ================================================================== + // Setting InOutId Values for Filtering + // ================================================================== + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_inout_set_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterInOutSetName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutSetName &>( + FuncInst->getHostFunc()); + + { + int32_t Length = InputFilterName.length(); + EXPECT_TRUE(HostFuncAVFilterInOutSetName.run( + CallFrame, + std::initializer_list{OutputInOutId, + InputFilterNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + Length = OutputFilterName.length(); + EXPECT_TRUE(HostFuncAVFilterInOutSetName.run( + CallFrame, + std::initializer_list{ + InputInOutId, OutputFilterNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_inout_set_filter_ctx"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterInOutSetFilterCtx = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutSetFilterCtx &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterInOutSetFilterCtx.run( + CallFrame, + std::initializer_list{OutputInOutId, + InputFilterCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFilterInOutSetFilterCtx.run( + CallFrame, + std::initializer_list{InputInOutId, + OutputFilterCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_inout_set_pad_idx"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterInOutSetPadIdx = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutSetPadIdx &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterInOutSetPadIdx.run( + CallFrame, + std::initializer_list{OutputInOutId, 0}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFilterInOutSetPadIdx.run( + CallFrame, std::initializer_list{InputInOutId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_inout_set_next"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterInOutSetNext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutSetNext &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterInOutSetNext.run( + CallFrame, + std::initializer_list{OutputInOutId, 0}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFilterInOutSetNext.run( + CallFrame, std::initializer_list{InputInOutId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + // ================================================================== + // End Setting InOutId Values for Filtering + // ================================================================== + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_parse_ptr"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphParsePtr = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphParsePtr &>( + FuncInst->getHostFunc()); + + { + int32_t Length = SpecStr.length(); + EXPECT_TRUE(HostFuncAVFilterGraphParsePtr.run( + CallFrame, + std::initializer_list{ + FilterGraphId, SpecPtr, Length, InputInOutId, OutputInOutId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_config"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphConfig = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphConfig &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGraphConfig.run( + CallFrame, std::initializer_list{FilterGraphId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_dump_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphDumpLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphDumpLength &>( + FuncInst->getHostFunc()); + + int32_t GraphStrLen = 0; + { + EXPECT_TRUE(HostFuncAVFilterGraphDumpLength.run( + CallFrame, std::initializer_list{FilterGraphId}, + Result)); + GraphStrLen = Result[0].get(); + ASSERT_TRUE(GraphStrLen > 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_dump"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphDump = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphDump &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGraphDump.run( + CallFrame, + std::initializer_list{FilterGraphId, StrPtr, + GraphStrLen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + // Crashing the program. Checked even from Rust side. + + // FuncInst = AVFilterMod->findFuncExports( + // "wasmedge_ffmpeg_avfilter_avfilter_inout_free"); + // EXPECT_NE(FuncInst, nullptr); + // EXPECT_TRUE(FuncInst->isHostFunction()); + // + // auto &HostFuncAVFilterInOutFree = dynamic_cast< + // WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutFree &>( + // FuncInst->getHostFunc()); + // + // { + // EXPECT_TRUE(HostFuncAVFilterInOutFree.run( + // CallFrame, + // std::initializer_list{InputInOutId}, Result)); + // EXPECT_EQ(Result[0].get(), + // static_cast(ErrNo::Success)); + // + // EXPECT_TRUE(HostFuncAVFilterInOutFree.run( + // CallFrame, + // std::initializer_list{OutputInOutId}, + // Result)); + // EXPECT_EQ(Result[0].get(), + // static_cast(ErrNo::Success)); + // } + + FuncInst = + AVFilterMod->findFuncExports("wasmedge_ffmpeg_avfilter_avfilter_version"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterVersion = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterVersion.run( + CallFrame, std::initializer_list{}, Result)); + ASSERT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_configuration_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterConfigurationLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterConfigurationLength &>( + FuncInst->getHostFunc()); + + int32_t Length = 0; + { + EXPECT_TRUE(HostFuncAVFilterConfigurationLength.run( + CallFrame, std::initializer_list{}, Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_configuration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterConfiguration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterConfiguration &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterConfiguration.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_license_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterLicenseLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterLicenseLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterLicenseLength.run( + CallFrame, std::initializer_list{}, Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = + AVFilterMod->findFuncExports("wasmedge_ffmpeg_avfilter_avfilter_license"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterLicense = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterLicense.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + // ================================================================== + // Start Test AVBufferSource, AVBufferSink Funcs + // ================================================================== + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_av_buffersrc_get_nb_failed_requests"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVBufferSrcGetNbFailedRequests = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVBufferSrcGetNbFailedRequests + &>(FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVBufferSrcGetNbFailedRequests.run( + CallFrame, + std::initializer_list{InputFilterCtxId}, Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_av_buffersrc_add_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVBufferSrcAddFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVBufferSrcAddFrame &>( + FuncInst->getHostFunc()); + + // Returning Error Code -22 (Invalid Argument), Due to Passing Empty Frame. + { + EXPECT_TRUE(HostFuncAVBufferSrcAddFrame.run( + CallFrame, + std::initializer_list{InputFilterCtxId, FrameId}, + Result)); + ASSERT_TRUE(Result[0].get()); + } + + // Need to send the last frame. Then only this test will pass. Else Null + // pointer exception. + // FuncInst = AVFilterMod->findFuncExports( + // "wasmedge_ffmpeg_avfilter_av_buffersrc_close"); + // EXPECT_NE(FuncInst, nullptr); + // EXPECT_TRUE(FuncInst->isHostFunction()); + // + // auto &HostFuncAVBufferSrcClose = dynamic_cast< + // WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVBufferSrcClose &>( + // FuncInst->getHostFunc()); + // + // { + // int64_t Pts = 20; + // uint32_t Flags = 30; + // EXPECT_TRUE(HostFuncAVBufferSrcClose.run( + // CallFrame, + // std::initializer_list{InputFilterCtxPtr, Pts, + // Flags}, + // Result)); + // EXPECT_EQ(Result[0].get(), + // static_cast(ErrNo::Success)); + // } + + // Passing Empty frames. Return AVERROR due to no frames presen Return AVERROR + // due to no frames present. + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_av_buffersink_get_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVBufferSinkGetFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVBufferSinkGetFrame &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVBufferSinkGetFrame.run( + CallFrame, + std::initializer_list{OutputFilterCtxId, FrameId}, + Result)); + ASSERT_TRUE(Result[0].get()); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_av_buffersink_get_samples"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVBufferSinkGetSamples = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVBufferSinkGetSamples &>( + FuncInst->getHostFunc()); + + // Passing Empty frames. Return AVERROR due to no frames presen Return AVERROR + // due to no frames present. + { + EXPECT_TRUE(HostFuncAVBufferSinkGetSamples.run( + CallFrame, + std::initializer_list{OutputFilterCtxId, FrameId, + 20}, + Result)); + ASSERT_TRUE(Result[0].get()); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_av_buffersink_set_frame_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAvBufferSinkSetFrameSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AvBufferSinkSetFrameSize &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAvBufferSinkSetFrameSize.run( + CallFrame, + std::initializer_list{OutputFilterCtxId, 30}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + // ================================================================== + // End Test AVBufferSource, AVBufferSink Funcs + // ================================================================== + + // ================================================================== + // Clean Memory + // ================================================================== + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_free_graph_str"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterFreeGraphStr = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterFreeGraphStr &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterFreeGraphStr.run( + CallFrame, std::initializer_list{FilterGraphId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVFilterMod->findFuncExports("wasmedge_ffmpeg_avfilter_avfilter_drop"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterDrop = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterDrop.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_context_drop"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterContextDrop = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterContextDrop &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterContextDrop.run( + CallFrame, + std::initializer_list{InputFilterCtxId}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFilterContextDrop.run( + CallFrame, + std::initializer_list{OutputFilterCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_free"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphFree = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphFree &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGraphFree.run( + CallFrame, std::initializer_list{FilterGraphId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + // ================================================================== + // End Clean Memory + // ================================================================== +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp new file mode 100644 index 00000000..299d26f9 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avformat/avChapter.h" +#include "avformat/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// Sample Video under test has only Single Chapter. +TEST_F(FFmpegTest, AVChapter) { + ASSERT_TRUE(AVFormatMod != nullptr); + + uint32_t ChapterIdx = 0; + + uint32_t FormatCtxPtr = UINT32_C(4); + uint32_t NumPtr = UINT32_C(12); + uint32_t DenPtr = UINT32_C(16); + uint32_t DictionaryPtr = UINT32_C(20); + uint32_t FilePtr = UINT32_C(100); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFormatCtx(FormatCtxPtr, FilePtr, FileName); + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + + auto *FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avChapter_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterId = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVChapterId.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avChapter_timebase"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterTimebase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterTimebase &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVChapterTimebase.run( + CallFrame, + std::initializer_list{NumPtr, DenPtr, FormatCtxId, + ChapterIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + EXPECT_EQ(readSInt32(MemInst, NumPtr), 1); + EXPECT_TRUE(readSInt32(MemInst, DenPtr) >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avChapter_start"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterStart = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVChapterStart.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avChapter_end"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterEnd = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVChapterEnd.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avChapter_metadata"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterMetadata &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVChapterMetadata.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx, + DictionaryPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, DictionaryPtr) > 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avChapter_set_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterSetId = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int64_t ChapterId = 10000; + EXPECT_TRUE( + HostFuncAVChapterSetId.run(CallFrame, + std::initializer_list{ + FormatCtxId, ChapterIdx, ChapterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + // Verify Set Data + EXPECT_TRUE(HostFuncAVChapterId.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_EQ(Result[0].get(), ChapterId); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avChapter_set_timebase"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterSetTimebase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterSetTimebase &>( + FuncInst->getHostFunc()); + + { + int32_t Num = 3; + int32_t Den = 4; + EXPECT_TRUE(HostFuncAVChapterSetTimebase.run( + CallFrame, + std::initializer_list{Num, Den, FormatCtxId, + ChapterIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + // Verify Set Data + EXPECT_TRUE(HostFuncAVChapterTimebase.run( + CallFrame, + std::initializer_list{NumPtr, DenPtr, FormatCtxId, + ChapterIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(readSInt32(MemInst, NumPtr), Num); + EXPECT_EQ(readSInt32(MemInst, DenPtr), Den); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avChapter_set_start"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterSetStart = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterSetStart &>( + FuncInst->getHostFunc()); + + { + int64_t StartValue = 1000; + EXPECT_TRUE(HostFuncAVChapterSetStart.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx, + StartValue}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + // Verify Set Data + EXPECT_TRUE(HostFuncAVChapterStart.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_EQ(Result[0].get(), StartValue); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avChapter_set_end"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterSetEnd = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int64_t EndValue = 99999; + EXPECT_TRUE( + HostFuncAVChapterSetEnd.run(CallFrame, + std::initializer_list{ + FormatCtxId, ChapterIdx, EndValue}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + // Verify Set Data + EXPECT_TRUE(HostFuncAVChapterEnd.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_EQ(Result[0].get(), EndValue); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp new file mode 100644 index 00000000..99a1fcbe --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp @@ -0,0 +1,210 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avformat/avInputOutputFormat.h" +#include "avformat/avformatContext.h" +#include "avformat/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVInputFormat) { + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + uint32_t FormatCtxPtr = UINT32_C(24); + uint32_t InputFormatPtr = UINT32_C(28); + + uint32_t StrBuf = UINT32_C(100); + initFFmpegStructs(UINT32_C(20), FormatCtxPtr, UINT32_C(28), FileName, + UINT32_C(60), UINT32_C(64), UINT32_C(68), UINT32_C(72)); + + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + + // ==================================================================== + // Initialize AVInputFormat + // ==================================================================== + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_iformat"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxIFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxIFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxIFormat.run( + CallFrame, + std::initializer_list{FormatCtxId, + InputFormatPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, InputFormatPtr) > 0); + } + uint32_t InputFormatId = readUInt32(MemInst, InputFormatPtr); + + // ==================================================================== + // End Initialize AVInputFormat + // ==================================================================== + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avIOFormat_name_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOFormatNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVIOFormatNameLength &>( + FuncInst->getHostFunc()); + + int32_t Length = 0; + { + EXPECT_TRUE(HostFuncAVIOFormatNameLength.run( + CallFrame, + std::initializer_list{InputFormatId, 0}, Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avInputFormat_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInputFormatName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInputFormatName &>( + FuncInst->getHostFunc()); + + fillMemContent(MemInst, StrBuf, Length); + { + EXPECT_TRUE(HostFuncAVInputFormatName.run( + CallFrame, + std::initializer_list{InputFormatId, StrBuf, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avIOFormat_long_name_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOFormatLongNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVIOFormatLongNameLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVIOFormatLongNameLength.run( + CallFrame, + std::initializer_list{InputFormatId, 0}, Result)); + + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avInputFormat_long_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInputFormatLongName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInputFormatLongName &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVInputFormatLongName.run( + CallFrame, + std::initializer_list{InputFormatId, StrBuf, + Length}, + Result)); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avIOFormat_extensions_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOFormatExtensionsLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVIOFormatExtensionsLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVIOFormatExtensionsLength.run( + CallFrame, + std::initializer_list{InputFormatId, 0}, Result)); + + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avInputFormat_extensions"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInputFormatExtensions = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInputFormatExtensions &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVInputFormatExtensions.run( + CallFrame, + std::initializer_list{InputFormatId, StrBuf, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avIOFormat_mime_type_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOFormatMimeTypeLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVIOFormatMimeTypeLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVIOFormatMimeTypeLength.run( + CallFrame, + std::initializer_list{InputFormatId, 0}, Result)); + + Length = Result[0].get(); + ASSERT_TRUE(Length >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avInputFormat_mime_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInputFormatMimeType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInputFormatMimeType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVInputFormatMimeType.run( + CallFrame, + std::initializer_list{InputFormatId, StrBuf, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avInputOutputFormat_free"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInputOutputFormatFree = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInputOutputFormatFree &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVInputOutputFormatFree.run( + CallFrame, std::initializer_list{InputFormatId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp new file mode 100644 index 00000000..0cc0655b --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avformat/avStream.h" +#include "avformat/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// Testing all AVFormat_funcs. +TEST_F(FFmpegTest, AVStreamStruct) { + ASSERT_TRUE(AVFormatMod != nullptr); + + uint32_t StreamIdx = 0; + + uint32_t FormatCtxPtr = UINT32_C(4); + uint32_t CodecParameterPtr = UINT32_C(8); + uint32_t NumPtr = UINT32_C(12); + uint32_t DenPtr = UINT32_C(16); + uint32_t DictPtr = UINT32_C(20); + uint32_t FilePtr = UINT32_C(100); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFormatCtx(FormatCtxPtr, FilePtr, FileName); + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + + auto *FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avStream_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamId = + dynamic_cast( + FuncInst->getHostFunc()); + + uint32_t AvFormatCtxId = readUInt32(MemInst, FormatCtxPtr); + { + EXPECT_TRUE(HostFuncAVStreamId.run( + CallFrame, + std::initializer_list{AvFormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avStream_index"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamIndex = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamIndex.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_codecpar"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamCodecPar = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamCodecPar &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamCodecPar.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx, + CodecParameterPtr}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + ASSERT_TRUE(readUInt32(MemInst, CodecParameterPtr) > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_timebase"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamTimebase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamTimebase &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_set_timebase"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamSetTimebase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamSetTimebase &>( + FuncInst->getHostFunc()); + + { + int32_t Num = 3; + int32_t Den = 4; + EXPECT_TRUE(HostFuncAVStreamSetTimebase.run( + CallFrame, + std::initializer_list{Num, Den, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVStreamTimebase.run( + CallFrame, + std::initializer_list{NumPtr, DenPtr, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(readUInt32(MemInst, NumPtr), Num); + EXPECT_EQ(readUInt32(MemInst, DenPtr), Den); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_duration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamDuration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamDuration &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamDuration.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_start_time"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamStartTime = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamStartTime &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamStartTime.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_nb_frames"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamNbFrames = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamNbFrames &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamNbFrames.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_disposition"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamDisposition = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamDisposition &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamDisposition.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_set_r_frame_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamSetRFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamSetRFrameRate &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_r_frame_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamRFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamRFrameRate &>( + FuncInst->getHostFunc()); + + { + int32_t Num = 3; + int32_t Den = 4; + EXPECT_TRUE(HostFuncAVStreamSetRFrameRate.run( + CallFrame, + std::initializer_list{Num, Den, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVStreamRFrameRate.run( + CallFrame, + std::initializer_list{NumPtr, DenPtr, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(readUInt32(MemInst, NumPtr), Num); + EXPECT_EQ(readUInt32(MemInst, DenPtr), Den); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_set_avg_frame_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamSetAvgFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamSetAvgFrameRate &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_avg_frame_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamAvgFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamAvgFrameRate &>( + FuncInst->getHostFunc()); + + { + int32_t Num = 3; + int32_t Den = 4; + + EXPECT_TRUE(HostFuncAVStreamSetAvgFrameRate.run( + CallFrame, + std::initializer_list{Num, Den, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVStreamAvgFrameRate.run( + CallFrame, + std::initializer_list{NumPtr, DenPtr, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(readUInt32(MemInst, NumPtr), Num); + EXPECT_EQ(readUInt32(MemInst, DenPtr), Den); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_metadata"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamMetadata &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_set_metadata"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamSetMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamSetMetadata &>( + FuncInst->getHostFunc()); + { + EXPECT_TRUE(HostFuncAVStreamMetadata.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx, + DictPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + uint32_t DictId = readUInt32(MemInst, DictPtr); + EXPECT_TRUE(HostFuncAVStreamSetMetadata.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx, + DictId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avStream_discard"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamDiscard = + dynamic_cast( + FuncInst->getHostFunc()); + { + EXPECT_TRUE(HostFuncAVStreamDiscard.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp new file mode 100644 index 00000000..3d23d403 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avformat/avformatContext.h" +#include "avformat/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// Testing all AVFormat_funcs. +TEST_F(FFmpegTest, AVFormatContextStruct) { + uint32_t FormatCtxPtr = UINT32_C(4); + uint32_t InputFormatPtr = UINT32_C(8); + uint32_t OutputFormatPtr = UINT32_C(12); + uint32_t DicPtr = uint32_t(16); + uint32_t FilePtr = UINT32_C(100); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFormatCtx(FormatCtxPtr, FilePtr, FileName); + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_iformat"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxIFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxIFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxIFormat.run( + CallFrame, + std::initializer_list{FormatCtxId, + InputFormatPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, InputFormatPtr) > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_oformat"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxOFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxOFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxOFormat.run( + CallFrame, + std::initializer_list{FormatCtxId, + OutputFormatPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, InputFormatPtr) > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_probescope"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxProbeScore = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxProbeScore &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxProbeScore.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 100); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_nb_streams"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxNbStreams = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxNbStreams &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxNbStreams.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_duration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxDuration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxDuration &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxDuration.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0 || Result[0].get() < 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_bit_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxBitRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxBitRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxBitRate.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_set_nb_chapters"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxSetNbChapters = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxSetNbChapters &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_nb_chapters"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxNbChapters = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxNbChapters &>( + FuncInst->getHostFunc()); + { + uint32_t NbChapters = 200; + EXPECT_TRUE(HostFuncAVFormatCtxSetNbChapters.run( + CallFrame, + std::initializer_list{FormatCtxId, NbChapters}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFormatCtxNbChapters.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), NbChapters); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_metadata"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxMetadata &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_set_metadata"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxSetMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxSetMetadata &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxMetadata.run( + CallFrame, + std::initializer_list{FormatCtxId, DicPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, DicPtr) > 0); + + uint32_t DictId = readUInt32(MemInst, DicPtr); + EXPECT_TRUE(HostFuncAVFormatCtxSetMetadata.run( + CallFrame, + std::initializer_list{FormatCtxId, DictId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp new file mode 100644 index 00000000..809a25d9 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp @@ -0,0 +1,591 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avformat/avformat_func.h" +#include "avformat/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// Testing all AVFormat_funcs. +TEST_F(FFmpegTest, AVInputFormatFunc) { + uint32_t FormatCtxPtr = UINT32_C(4); + uint32_t DictPtr = UINT32_C(16); + uint32_t KeyPtr = UINT32_C(100); + uint32_t ValuePtr = UINT32_C(200); + uint32_t StrPtr = UINT32_C(400); + + initDict(DictPtr, KeyPtr, std::string("Key"), ValuePtr, std::string("Value")); + uint32_t DictId = readUInt32(MemInst, DictPtr); + + uint32_t UrlStart = UINT32_C(300); + uint32_t UrlSize = 30; + fillMemContent(MemInst, UrlStart, UrlSize); + fillMemContent(MemInst, UrlStart, + std::string("ffmpeg-assets/sample_video.mp4")); + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_open_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatOpenInput = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatOpenInput &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatOpenInput"sv); + { + // AVDict only + EXPECT_TRUE(HostFuncAVFormatOpenInput.run( + CallFrame, + std::initializer_list{ + FormatCtxPtr, UrlStart, UrlSize, UINT32_C(0), DictId}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + EXPECT_TRUE(readUInt32(MemInst, FormatCtxPtr) > 0); + + // No AVDict, No AVInputFormat + EXPECT_TRUE(HostFuncAVFormatOpenInput.run( + CallFrame, + std::initializer_list{ + FormatCtxPtr, UrlStart, UrlSize, UINT32_C(0), UINT32_C(0)}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + EXPECT_TRUE(readUInt32(MemInst, FormatCtxPtr) > 0); + } + + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_find_stream_info"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatFindStreamInfo = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatFindStreamInfo &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatFindStreamInfo"sv); + { + EXPECT_TRUE(HostFuncAVFormatFindStreamInfo.run( + CallFrame, + std::initializer_list{FormatCtxId, UINT32_C(0)}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_dump_format"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatAVDumpFormat = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVDumpFormat"sv); + { + EXPECT_TRUE(HostFuncAVFormatAVDumpFormat.run( + CallFrame, + std::initializer_list{ + FormatCtxId, 0, UINT32_C(100), UINT32_C(30), 0}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_av_find_best_stream"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFindBestStream = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFindBestStream &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFindBestStream"sv); + { + EXPECT_TRUE(HostFuncAVFindBestStream.run( + CallFrame, + std::initializer_list{ + FormatCtxId, UINT32_C(0), INT32_C(-1), INT32_C(-1), 0, 0}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_read_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVReadFrame = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVReadFrame"sv); + { + uint32_t PacketPtr = UINT32_C(520); + allocPacket(PacketPtr); + uint32_t PacketId = readUInt32(MemInst, PacketPtr); + EXPECT_TRUE(HostFuncAVReadFrame.run( + CallFrame, + std::initializer_list{FormatCtxId, PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_network_init"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatNetworkInit = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatNetworkInit &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatNetworkInit"sv); + { + EXPECT_TRUE(HostFuncAVFormatNetworkInit.run( + CallFrame, std::initializer_list{}, Result)); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_seek_file"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatSeekFile = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatSeekFile &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatSeekFile"sv); + { + uint32_t StreamIdx = -1; + int64_t MinTs = -10; + int64_t Ts = 0; + int64_t MaxTs = 10; + int32_t Flags = 0; + + // Try a network Fetch. + EXPECT_TRUE(HostFuncAVFormatSeekFile.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx, + MinTs, Ts, MaxTs, Flags}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_read_play"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatAVReadPlay = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVReadPlay"sv); + { + // Try a network Fetch. + EXPECT_TRUE(HostFuncAVFormatAVReadPlay.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() < 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_read_pause"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatAVReadPause = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVReadPause"sv); + { + // Try a network Fetch. + EXPECT_TRUE(HostFuncAVFormatAVReadPause.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() < 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_network_deinit"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatNetworkDeInit = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatNetworkDeInit &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatNetworkDeInit"sv); + { + EXPECT_TRUE(HostFuncAVFormatNetworkDeInit.run( + CallFrame, std::initializer_list{}, Result)); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_close_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCloseInput = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCloseInput &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatCloseInput"sv); + { + EXPECT_TRUE(HostFuncAVFormatCloseInput.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_free_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFreeContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatFreeContext &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatFreeContext"sv); + { + EXPECT_TRUE(HostFuncAVFreeContext.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avformat_version"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatVersion = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatVersion"sv); + { + EXPECT_TRUE(HostFuncAVFormatVersion.run( + CallFrame, std::initializer_list{}, Result)); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_configuration_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatConfigurationLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatConfigurationLength &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatConfigurationLength"sv); + int32_t Length = 0; + { + EXPECT_TRUE(HostFuncAVFormatConfigurationLength.run( + CallFrame, std::initializer_list{}, Result)); + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_configuration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatConfiguration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatConfiguration &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatConfiguration"sv); + { + EXPECT_TRUE(HostFuncAVFormatConfiguration.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_license_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatLicenseLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatLicenseLength &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatLicenseLength"sv); + { + EXPECT_TRUE(HostFuncAVFormatLicenseLength.run( + CallFrame, std::initializer_list{}, Result)); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avformat_license"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatLicense = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatLicense"sv); + { + EXPECT_TRUE(HostFuncAVFormatLicense.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + + EXPECT_TRUE(Result[0].get() >= 0); + } +} + +TEST_F(FFmpegTest, AVOutputFormatFunc) { + + uint32_t FormatCtxPtr = UINT32_C(4); + uint32_t DictPtr = UINT32_C(16); + uint32_t ChapterPtr = UINT32_C(20); + uint32_t FramePtr = UINT32_C(24); + uint32_t KeyPtr = UINT32_C(100); + uint32_t ValuePtr = UINT32_C(200); + + initDict(DictPtr, KeyPtr, std::string("Key"), ValuePtr, std::string("Value")); + initEmptyFrame(FramePtr); + uint32_t DictId = readUInt32(MemInst, DictPtr); + uint32_t FrameId = readUInt32(MemInst, FramePtr); + + uint32_t FormatStart = 300; + uint32_t FormatLen = 3; + uint32_t FileStart = 350; + uint32_t FileLen = 8; + fillMemContent(MemInst, FormatStart, FormatLen + FileLen); + + fillMemContent(MemInst, FormatStart, "mp4"sv); + fillMemContent(MemInst, FileStart, "test.mp4"sv); + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_alloc_output_context2"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatAllocOutputContext2 = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatAllocOutputContext2 &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatAllocOutputContext2"sv); + { + EXPECT_TRUE(HostFuncAVFormatAllocOutputContext2.run( + CallFrame, + std::initializer_list{ + FormatCtxPtr, 0, FormatStart, FormatLen, FileStart, FileLen}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + + EXPECT_TRUE(HostFuncAVFormatAllocOutputContext2.run( + CallFrame, + std::initializer_list{ + FormatCtxPtr, readUInt32(MemInst, FormatCtxPtr), FormatStart, + FormatLen, FileStart, FileLen}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + + EXPECT_TRUE(HostFuncAVFormatAllocOutputContext2.run( + CallFrame, + std::initializer_list{FormatCtxPtr, 0, 0, 0, + FileStart, FileLen}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avio_open"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOOpen = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVIOOpen"sv); + { + uint32_t AvFormatCtxId = readUInt32(MemInst, FormatCtxPtr); + EXPECT_TRUE( + HostFuncAVIOOpen.run(CallFrame, + std::initializer_list{ + AvFormatCtxId, FileStart, FileLen, 2}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avio_open2"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOOpen2 = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVIOOpen2"sv); + { + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + EXPECT_TRUE(HostFuncAVIOOpen2.run( + CallFrame, + std::initializer_list{FormatCtxId, FileStart, + FileLen, 2, 0, DictId}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + // TODO: This test modifies the input file, so it cannot be tested. + // Added test on the Rust side. + // spdlog::info("Testing AVGuessCodec"sv); + // uint32_t EmptyStrPtr = UINT32_C(520); + // writeUInt32(MemInst, 0, EmptyStrPtr); + // { + // uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + // int32_t MediaTypeId = 0; // Video + // EXPECT_TRUE(HostFuncAVGuessCodec.run( + // CallFrame, + // std::initializer_list{FormatCtxId, + // EmptyStrPtr, 0, + // FilePtr, 32, + // EmptyStrPtr, 0, + // MediaTypeId}, + // Result)); + // EXPECT_EQ(Result[0].get(), 1); // AV_CODEC_ID_MPEG1VIDEO: + // } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_write_header"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatWriteHeader = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatWriteHeader &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatWriteHeader"sv); + { + // Did not set AVParameters, etc. Hence Giving Invalid Argument Error. + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + EXPECT_TRUE(HostFuncAVFormatWriteHeader.run( + CallFrame, + std::initializer_list{FormatCtxId, DictId}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + // Writing the header above returns an invalid argument, so the test below + // does not work. The OutputFormatContext should be configured using the input + // format context. Test this on the Rust side. This is working as expected. + + // FuncInst = AVFormatMod->findFuncExports( + // "wasmedge_ffmpeg_avformat_avformat_write_trailer"); + // EXPECT_NE(FuncInst, nullptr); + // EXPECT_TRUE(FuncInst->isHostFunction()); + // auto &HostFuncAVFormatTrailer = dynamic_cast< + // WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatWriteTrailer &>( + // FuncInst->getHostFunc()); + // { + // // Did not set AVParameters, etc. Hence Giving Invalid Argument Error. + // uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + // EXPECT_TRUE(HostFuncAVFormatTrailer.run( + // CallFrame, std::initializer_list{FormatCtxId}, + // Result)); + // EXPECT_EQ(Result[0].get(), -22); + // } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avchapter_mallocz"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOClose = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterMallocz &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVChapterMallocz"sv); + { + EXPECT_TRUE(HostFuncAVIOClose.run( + CallFrame, std::initializer_list{ChapterPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + // How to pass IntPtr + // + // FuncInst = AVFormatMod->findFuncExports( + // "wasmedge_ffmpeg_avformat_avchapter_dynarray_add"); + // EXPECT_NE(FuncInst, nullptr); + // EXPECT_TRUE(FuncInst->isHostFunction()); + // auto &HostFuncAVChapterDynarrayAdd = dynamic_cast< + // WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterDynarrayAdd &>( + // FuncInst->getHostFunc()); + // // For the given input file, nb_chapter is 0; + // { + // uint32_t AvChapterId = readUInt32(MemInst, AvFormatCtxPtr); + // uint32_t AvFormatCtxId = readUInt32(MemInst, AvFormatCtxPtr); + // EXPECT_TRUE(HostFuncAVChapterDynarrayAdd.run( + // CallFrame, + // std::initializer_list{AvFormatCtxId, + // UINT32_C(0), + // AvChapterId}, + // Result)); + // EXPECT_EQ(Result[0].get(), + // static_cast(ErrNo::Success)); + // } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avformat_avfreep"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatAVFreep = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFreeP"sv); + { + uint32_t ChapterId = readUInt32(MemInst, ChapterPtr); + EXPECT_TRUE(HostFuncAVFormatAVFreep.run( + CallFrame, std::initializer_list{ChapterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_write_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVWriteFrame = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVWriteFrame"sv); + // Passing Empty Frame, Hence giving Invalid Argument Error. + { + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + EXPECT_TRUE(HostFuncAVWriteFrame.run( + CallFrame, + std::initializer_list{FormatCtxId, FrameId}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_av_interleaved_write_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInterleavedWriteFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInterleavedWriteFrame &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVInterleavedWriteFrame"sv); + // Passing Empty Frame, Hence giving Invalid Argument Error. + { + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + EXPECT_TRUE(HostFuncAVInterleavedWriteFrame.run( + CallFrame, + std::initializer_list{FormatCtxId, FrameId}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp new file mode 100644 index 00000000..779fb236 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avutil/avDictionary.h" +#include "avutil/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVDictionary) { + using namespace std::literals::string_view_literals; + + uint32_t KeyStart = UINT32_C(1); + uint32_t KeyLen = 3; + uint32_t ValueStart = UINT32_C(4); + uint32_t ValueLen = 5; + uint32_t PrevDictEntryIdx = 0; // The Fetch the next Key value Node using an + // index. Passing Index from Rust side. + int32_t Flags = 0; + uint32_t NullDictId = UINT32_C(0); + + uint32_t DictPtr = UINT32_C(80); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_dict_set"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDictSet = + dynamic_cast( + FuncInst->getHostFunc()); + + // Fill 0 in WasmMemory. + fillMemContent(MemInst, KeyStart, KeyLen + ValueLen); + fillMemContent(MemInst, KeyStart, "KEY"sv); + fillMemContent(MemInst, ValueStart, "VALUE"sv); + + // Storing the above Key and Value in dict and using these in below tests + // (dict_get) to fetch Key,values. + { + EXPECT_TRUE(HostFuncAVDictSet.run( + CallFrame, + std::initializer_list{ + DictPtr, KeyStart, KeyLen, ValueStart, ValueLen, Flags}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + ASSERT_TRUE(readUInt32(MemInst, DictPtr) > 0); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_dict_copy"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDictCopy = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t DestDictPtr = UINT32_C(80); + uint32_t SrcDictId = readUInt32(MemInst, DictPtr); + EXPECT_TRUE( + HostFuncAVDictCopy.run(CallFrame, + std::initializer_list{ + DestDictPtr, SrcDictId, Flags}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_dict_get"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDictGet = + dynamic_cast( + FuncInst->getHostFunc()); + + { + // Store the string lengths of Key and value in the pointers below. + uint32_t KeyLenPtr = UINT32_C(56); + uint32_t ValueLenPtr = UINT32_C(60); + uint32_t DictId = readUInt32(MemInst, DictPtr); + EXPECT_TRUE(HostFuncAVDictGet.run( + CallFrame, + std::initializer_list{DictId, KeyStart, KeyLen, + PrevDictEntryIdx, Flags, + KeyLenPtr, ValueLenPtr}, + Result)); + EXPECT_TRUE(Result[0].get() == 1); + EXPECT_EQ(readUInt32(MemInst, KeyLenPtr), KeyLen); + EXPECT_EQ(readUInt32(MemInst, ValueLenPtr), ValueLen); + + // Pass a Null Dict and testing. + EXPECT_TRUE(HostFuncAVDictGet.run( + CallFrame, + std::initializer_list{ + NullDictId, KeyStart, KeyLen, PrevDictEntryIdx, Flags, KeyLenPtr, + ValueLenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), + static_cast(ErrNo::InternalError)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_dict_get_key_value"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDictGetKeyValue = + dynamic_cast( + FuncInst->getHostFunc()); + + { + // Store the strings of Key and value in the buffer pointers below. + uint32_t KeyBufPtr = UINT32_C(36); + uint32_t ValueBufPtr = UINT32_C(40); + uint32_t DictId = readUInt32(MemInst, DictPtr); + EXPECT_TRUE(HostFuncAVDictGetKeyValue.run( + CallFrame, + std::initializer_list{ + DictId, KeyStart, KeyLen, ValueBufPtr, ValueLen, KeyBufPtr, + UINT32_C(3), PrevDictEntryIdx, Flags}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + // Verify String. Read String from MemInst + + // Pass a Null Dict and testing. + EXPECT_TRUE(HostFuncAVDictGetKeyValue.run( + CallFrame, + std::initializer_list{ + NullDictId, KeyStart, KeyLen, ValueBufPtr, ValueLen, KeyBufPtr, + UINT32_C(3), PrevDictEntryIdx, Flags}, + Result)); + EXPECT_EQ(Result[0].get(), + static_cast(ErrNo::InternalError)); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_dict_free"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDictFree = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t DictId = readUInt32(MemInst, DictPtr); + EXPECT_TRUE(HostFuncAVDictFree.run( + CallFrame, std::initializer_list{DictId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp new file mode 100644 index 00000000..8584c97b --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avutil/error.h" +#include "avutil/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVError) { + using namespace std::literals::string_view_literals; + ASSERT_TRUE(AVUtilMod != nullptr); + + int32_t ErrNum = 35; + uint32_t ErrStartPtr = UINT32_C(100); + uint32_t ErrSize = 10; + fillMemContent(MemInst, ErrStartPtr, ErrSize); + fillMemContent(MemInst, ErrStartPtr, "Test Error"sv); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_strerror"); + auto &HostFuncAVUtilAVStrError = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilAVStrError.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_AVERROR"); + auto &HostFuncAVUtilAVError = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilAVError.run( + CallFrame, std::initializer_list{ErrNum}, Result); + + EXPECT_EQ(Result[0].get(), + ErrNum * -1); // Returns Negative, convert to Positive + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_AVUNERROR"); + auto &HostFuncAVUtilAVUNError = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilAVUNError.run( + CallFrame, std::initializer_list{ErrNum}, Result); + + EXPECT_EQ(Result[0].get(), + ErrNum * -1); // Returns Negative, convert to Positive + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp new file mode 100644 index 00000000..dd20f625 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp @@ -0,0 +1,758 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avutil/avFrame.h" +#include "avutil/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVFrame) { + uint32_t AVFramePtr = UINT32_C(72); + uint32_t AVFrame2Ptr = UINT32_C(40); + uint32_t DictPtr = UINT32_C(36); + uint32_t NumPtr = UINT32_C(80); + uint32_t DenPtr = UINT32_C(84); + uint32_t BufPtr = UINT32_C(200); // TO store Frame Data; + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFFmpegStructs(UINT32_C(12), UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), UINT32_C(64), UINT32_C(68), AVFramePtr); + + initFFmpegStructs(UINT32_C(100), UINT32_C(104), UINT32_C(108), FileName, + UINT32_C(112), UINT32_C(116), UINT32_C(120), AVFrame2Ptr); + + uint32_t AVFrameId = readUInt32(MemInst, AVFramePtr); + uint32_t AVFrame2Id = readUInt32(MemInst, AVFrame2Ptr); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_alloc"); + auto &HostFuncAVFrameAlloc = + dynamic_cast( + FuncInst->getHostFunc()); + + uint32_t EmptyFramePtr = UINT32_C(64); + + { + HostFuncAVFrameAlloc.run( + CallFrame, std::initializer_list{EmptyFramePtr}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, EmptyFramePtr) > 0); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_free"); + auto &HostFuncAVFrameFree = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t EmptyFrameId = readUInt32(MemInst, EmptyFramePtr); + HostFuncAVFrameFree.run( + CallFrame, std::initializer_list{EmptyFrameId}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_width"); + auto &HostFuncAVFrameWidth = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameWidth.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + + EXPECT_EQ(Result[0].get(), 1920); // Width + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_height"); + auto &HostFuncAVFrameHeight = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameHeight.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + + EXPECT_EQ(Result[0].get(), 1080); // Height + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_video_format"); + auto &HostFuncAVFrameVideoFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameVideoFormat &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameVideoFormat.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + + EXPECT_EQ(Result[0].get(), 1); // Video Format + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_isnull"); + auto &HostFuncAVFrameIsNull = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameIsNull.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_linesize"); + auto &HostFuncAVFrameLinesize = + dynamic_cast( + FuncInst->getHostFunc()); + + int32_t Stride = 0; + uint32_t Idx = 0; + { + HostFuncAVFrameLinesize.run( + CallFrame, std::initializer_list{AVFrameId, Idx}, + Result); + + Stride = Result[0].get(); + EXPECT_EQ(Stride, 1920); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_get_buffer"); + auto &HostFuncAVFrameGetBuffer = + dynamic_cast( + FuncInst->getHostFunc()); + { + // For video, it is 32. + int32_t Align = 32; + HostFuncAVFrameGetBuffer.run( + CallFrame, + std::initializer_list{AVFrameId, Align}, Result); + + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_best_effort_timestamp"); + auto &HostFuncAVFrameBestEffortTimestamp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameBestEffortTimestamp &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameBestEffortTimestamp.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_pict_type"); + auto &HostFuncAVFramePictType = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFramePictType.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_interlaced_frame"); + auto &HostFuncAVFrameInterlacedFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameInterlacedFrame &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameInterlacedFrame.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_TRUE(Result[0].get() == 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_top_field_first"); + auto &HostFuncAVFrameTopFieldFirst = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameTopFieldFirst &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameTopFieldFirst.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_TRUE(Result[0].get() == 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_palette_has_changed"); + auto &HostFuncAVFramePaletteHasChanged = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFramePaletteHasChanged &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFramePaletteHasChanged.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_TRUE(Result[0].get() == 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_colorspace"); + auto &HostFuncAVFrameColorspace = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameColorspace.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 2); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_color_range"); + auto &HostFuncAVFrameColorRange = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameColorRange.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_color_trc"); + auto &HostAVFrameColorTransferCharacteristic = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameColorTransferCharacteristic + &>(FuncInst->getHostFunc()); + + { + HostAVFrameColorTransferCharacteristic.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 2); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_chroma_location"); + auto &HostAVFrameChromaLocation = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameChromaLocation &>( + FuncInst->getHostFunc()); + + { + HostAVFrameChromaLocation.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_repeat_pict"); + auto &HostAVFrameRepeatPict = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameRepeatPict.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_flags"); + auto &HostAVFrameFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameFlags.run(CallFrame, + std::initializer_list{AVFrameId}, + Result); + EXPECT_TRUE(Result[0].get() != 1 << 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_quality"); + auto &HostAVFrameQuality = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameQuality.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_metadata"); + auto &HostAVFrameMetadata = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameMetadata.run( + CallFrame, + std::initializer_list{AVFrameId, DictPtr}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + uint32_t DictId = readUInt32(MemInst, DictPtr); + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_metadata"); + auto &HostAVFrameSetMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetMetadata &>( + FuncInst->getHostFunc()); + + { + HostAVFrameSetMetadata.run( + CallFrame, + std::initializer_list{AVFrameId, DictId}, Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_key_frame"); + auto &HostAVFrameKeyFrame = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameKeyFrame.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_pts"); + auto &HostAVFramePts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFramePts.run(CallFrame, + std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_copy"); + auto &HostAVFrameCopy = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameCopy.run( + CallFrame, + std::initializer_list{AVFrame2Id, AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_copy_props"); + auto &HostAVFrameCopyProps = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameCopyProps.run( + CallFrame, + std::initializer_list{AVFrame2Id, AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_set_width"); + auto &HostFuncAVFrameSetWidth = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t Width = 100; + HostFuncAVFrameSetWidth.run( + CallFrame, + std::initializer_list{AVFrameId, Width}, Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFrameWidth.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), Width); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_set_height"); + auto &HostFuncAVFrameSetHeight = + dynamic_cast( + FuncInst->getHostFunc()); + + int32_t Height = 100; + { + HostFuncAVFrameSetHeight.run( + CallFrame, + std::initializer_list{AVFrameId, Height}, Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFrameHeight.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), Height); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_data"); + auto &HostFuncAVFrameData = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t Size = 1; // Just reading One byte data for test. + fillMemContent(MemInst, BufPtr, Size); + HostFuncAVFrameData.run(CallFrame, + std::initializer_list{ + AVFrameId, BufPtr, Size, Idx}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_video_format"); + auto &HostFuncAVFrameSetVideoFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetVideoFormat &>( + FuncInst->getHostFunc()); + + { + uint32_t PixFormatId = 10; // GRAY8 + HostFuncAVFrameSetVideoFormat.run( + CallFrame, + std::initializer_list{AVFrameId, PixFormatId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFrameVideoFormat.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), PixFormatId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_pict_type"); + auto &HostFuncAVFrameSetPictType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetPictType &>( + FuncInst->getHostFunc()); + + { + int32_t PictureId = 4; // AV_PICTURE_TYPE_S + HostFuncAVFrameSetPictType.run( + CallFrame, + std::initializer_list{AVFrameId, PictureId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFramePictType.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), PictureId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_colorspace"); + auto &HostFuncAVFrameSetColorSpace = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetColorSpace &>( + FuncInst->getHostFunc()); + + { + int32_t ColorSpaceId = 4; // FCC + HostFuncAVFrameSetColorSpace.run( + CallFrame, + std::initializer_list{AVFrameId, ColorSpaceId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFrameColorspace.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), ColorSpaceId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_color_range"); + auto &HostFuncAVFrameSetColorRange = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetColorRange &>( + FuncInst->getHostFunc()); + + { + int32_t ColorRangeId = 1; // MPEG + HostFuncAVFrameSetColorRange.run( + CallFrame, + std::initializer_list{AVFrameId, ColorRangeId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFrameColorRange.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), ColorRangeId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_color_trc"); + auto &HostFuncAVFrameSetColorTransferCharacteristic = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ColorTrcId = 5; // GAMMA28 + HostFuncAVFrameSetColorTransferCharacteristic.run( + CallFrame, + std::initializer_list{AVFrameId, ColorTrcId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostAVFrameColorTransferCharacteristic.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), ColorTrcId); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_set_pts"); + auto &HostFuncAVFrameSetPts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int64_t Pts = 10; + HostFuncAVFrameSetPts.run( + CallFrame, std::initializer_list{AVFrameId, Pts}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostAVFramePts.run(CallFrame, + std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), Pts); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_sample_aspect_ratio"); + auto &HostFuncAVFrameSampleAspectRatio = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSampleAspectRatio &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameSampleAspectRatio.run( + CallFrame, + std::initializer_list{AVFrameId, NumPtr, DenPtr}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + int32_t ColorPrimariesId = 1; // BT709 + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_color_primaries"); + auto &HostFuncAVFrameSetColorPrimaries = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetColorPrimaries &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameSetColorPrimaries.run( + CallFrame, + std::initializer_list{AVFrameId, + ColorPrimariesId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_color_primaries"); + auto &HostFuncAVFrameColorPrimaries = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameColorPrimaries &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameColorPrimaries.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), ColorPrimariesId); + } + + // ========================================================================== + // AVFrame Audio Funcs. + // ========================================================================== + + // Setting the fields to Video Frame itself. + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_audio_format"); + auto &HostFuncAVFrameSetAudioFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetAudioFormat &>( + FuncInst->getHostFunc()); + + uint32_t SampleFormatId = 4; + { + HostFuncAVFrameSetAudioFormat.run( + CallFrame, + std::initializer_list{AVFrameId, SampleFormatId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_audio_format"); + auto &HostFuncAVFrameAudioFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameAudioFormat &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameAudioFormat.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), SampleFormatId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_nb_samples"); + auto &HostFuncAVFrameSetNbSamples = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetNbSamples &>( + FuncInst->getHostFunc()); + + int32_t NbSamples = 32; + { + HostFuncAVFrameSetNbSamples.run( + CallFrame, + std::initializer_list{AVFrameId, NbSamples}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_nb_samples"); + auto &HostFuncAVFrameNbSamples = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameNbSamples.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), NbSamples); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_sample_rate"); + auto &HostFuncAVFrameSetSampleRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetSampleRate &>( + FuncInst->getHostFunc()); + + int32_t SampleRate = 10; + { + HostFuncAVFrameSetSampleRate.run( + CallFrame, + std::initializer_list{AVFrameId, SampleRate}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_sample_rate"); + auto &HostFuncAVFrameSampleRate = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameSampleRate.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), SampleRate); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_channels"); + auto &HostFuncAVFrameSetChannels = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetChannels &>( + FuncInst->getHostFunc()); + + int32_t Channels = 3; + { + HostFuncAVFrameSetChannels.run( + CallFrame, + std::initializer_list{AVFrameId, Channels}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_channels"); + auto &HostFuncAVFrameChannels = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameChannels.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), Channels); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_channel_layout"); + auto &HostFuncAVFrameSetChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetChannelLayout &>( + FuncInst->getHostFunc()); + + uint64_t ChannelLayout = 1UL << 10; + { + HostFuncAVFrameSetChannelLayout.run( + CallFrame, + std::initializer_list{AVFrameId, ChannelLayout}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_channel_layout"); + auto &HostFuncAVFrameChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameChannelLayout &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameChannelLayout.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), ChannelLayout); + } + + // ========================================================================== + // AVFrame Audio Funcs. + // ========================================================================== +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp new file mode 100644 index 00000000..625acd10 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avutil/module.h" +#include "avutil/pixfmt.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVPixFmt) { + uint32_t NamePtr = UINT32_C(4); + + auto *FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_avpixfmtdescriptor_nb_components"); + auto &HostFuncAVPixFmtDescriptorNbComponents = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AvPixFmtDescriptorNbComponents &>( + FuncInst->getHostFunc()); + + uint32_t PixFmtId = 3; // RGB24 + + { + HostFuncAVPixFmtDescriptorNbComponents.run( + CallFrame, std::initializer_list{PixFmtId}, + Result); + + EXPECT_EQ(Result[0].get(), PixFmtId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_avpixfmtdescriptor_log2_chromaw"); + auto &HostFuncAvPixFmtDescriptorLog2ChromaW = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AvPixFmtDescriptorLog2ChromaW &>( + FuncInst->getHostFunc()); + + { + HostFuncAvPixFmtDescriptorLog2ChromaW.run( + CallFrame, std::initializer_list{1}, Result); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_avpixfmtdescriptor_log2_chromah"); + auto &HostFuncAvPixFmtDescriptorLog2ChromaH = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AvPixFmtDescriptorLog2ChromaH &>( + FuncInst->getHostFunc()); + + { + HostFuncAvPixFmtDescriptorLog2ChromaH.run( + CallFrame, std::initializer_list{PixFmtId}, + Result); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + int32_t Length = 0; + int32_t TransferCharacteristicId = 6; // (SMPTE170M) + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_transfer_name_length"); + auto &HostFuncAVColorTransferNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorTransferNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorTransferNameLength.run( + CallFrame, + std::initializer_list{TransferCharacteristicId}, + Result); + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill memory with zero. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_transfer_name"); + auto &HostFuncAVColorTransferName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorTransferName &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorTransferName.run( + CallFrame, + std::initializer_list{TransferCharacteristicId, + NamePtr, Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + int32_t ColorRangeId = 2; //; JPEG + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_range_name_length"); + auto &HostFuncAVColorRangeNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorRangeNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorRangeNameLength.run( + CallFrame, std::initializer_list{ColorRangeId}, + Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill memory with zero. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_color_range_name"); + auto &HostFuncAVColorRangeName = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVColorRangeName.run(CallFrame, + std::initializer_list{ + ColorRangeId, NamePtr, Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + int32_t ColorSpaceId = 1; // BT709 + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_space_name_length"); + auto &HostFuncAVColorSpaceNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorSpaceNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorSpaceNameLength.run( + CallFrame, std::initializer_list{ColorSpaceId}, + Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill memory with zero. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_color_space_name"); + auto &HostFuncAVColorSpaceName = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVColorSpaceName.run(CallFrame, + std::initializer_list{ + ColorSpaceId, NamePtr, Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + int32_t ColorPrimariesId = 1; // BT709 + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_primaries_name_length"); + auto &HostFuncAVColorPrimariesNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorPrimariesNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorPrimariesNameLength.run( + CallFrame, + std::initializer_list{ColorPrimariesId}, Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill memory with zero. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_primaries_name"); + auto &HostFuncAVColorPrimariesName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorPrimariesName &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorPrimariesName.run( + CallFrame, + std::initializer_list{ColorPrimariesId, NamePtr, + Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + PixFmtId = 1; // YUV420P + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_pix_format_name_length"); + auto &HostFuncAVPixFormatNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVPixelFormatNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVPixFormatNameLength.run( + CallFrame, std::initializer_list{PixFmtId}, + Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill memory with zero. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_pix_format_name"); + auto &HostFuncAVPixFormatName = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVPixFormatName.run( + CallFrame, + std::initializer_list{PixFmtId, NamePtr, Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_pix_format_mask"); + auto &HostFuncAVPixFormatMask = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t PixId = 3; // AV_PIX_FMT_RGB24: + HostFuncAVPixFormatMask.run( + CallFrame, std::initializer_list{PixId}, Result); + + EXPECT_EQ(Result[0].get(), + 2); // Verify Mask. Position of Pix in Enum. + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp new file mode 100644 index 00000000..85223ad9 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avutil/avRational.h" +#include "avutil/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVRational) { + ASSERT_TRUE(AVUtilMod != nullptr); + + uint32_t NumPtr = UINT32_C(4); + uint32_t DenPtr = UINT32_C(8); + + // Addition Function + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_add_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVAddQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = 3; + int32_t ADen = 4; + int32_t BNum = -6; + int32_t BDen = 7; + EXPECT_TRUE(HostFuncAVAddQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), -3); + EXPECT_EQ(readSInt32(MemInst, DenPtr), 28); + } + + // Subtraction Function + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_sub_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVSubQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = -843; + int32_t ADen = 11; + int32_t BNum = 38; + int32_t BDen = 12; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + EXPECT_TRUE(HostFuncAVSubQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), -5267); + EXPECT_EQ(readSInt32(MemInst, DenPtr), 66); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_mul_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVMulQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = -6; + int32_t ADen = 7; + int32_t BNum = 3; + int32_t BDen = 4; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + EXPECT_TRUE(HostFuncAVMulQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), -9); + EXPECT_EQ(readSInt32(MemInst, DenPtr), 14); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_div_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDivQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = -6; + int32_t ADen = 7; + int32_t BNum = 3; + int32_t BDen = 4; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + EXPECT_TRUE(HostFuncAVDivQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), -8); + EXPECT_EQ(readSInt32(MemInst, DenPtr), 7); + } + + // How to Pass a Double functions. + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_d2q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVD2Q = + dynamic_cast( + FuncInst->getHostFunc()); + + { + double D = 5; + int32_t Max = 10; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + + EXPECT_TRUE(HostFuncAVD2Q.run( + CallFrame, + std::initializer_list{D, Max, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), 5); + EXPECT_EQ(readSInt32(MemInst, DenPtr), 1); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_q2d"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVQ2d = + dynamic_cast( + FuncInst->getHostFunc()); + + { + // Convert Rational Number to Double. + int32_t ANum = 1; + int32_t ADen = 2; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + EXPECT_TRUE(HostFuncAVQ2d.run( + CallFrame, std::initializer_list{ANum, ADen}, + Result)); + EXPECT_EQ(Result[0].get(), 0.5); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_inv_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInvQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + // Inverse a Rational Number. + int32_t ANum = -3; + int32_t ADen = 4; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + EXPECT_TRUE(HostFuncInvQ.run( + CallFrame, + std::initializer_list{ANum, ADen, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), 4); + EXPECT_EQ(readSInt32(MemInst, DenPtr), -3); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_q2intfloat"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVQ2IntFloat = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = 1; + int32_t ADen = 5; + + EXPECT_TRUE(HostFuncAVQ2IntFloat.run( + CallFrame, std::initializer_list{ANum, ADen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(1045220557)); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_nearer_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVNearerQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = 1; + int32_t ADen = 3; + int32_t BNum = 1; + int32_t BDen = 2; + int32_t CNum = -1; + int32_t CDen = 2; + + // B nearer to A + EXPECT_TRUE( + HostFuncAVNearerQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, CNum, CDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(1)); + + ANum = -1; + + // C nearer to A + EXPECT_TRUE( + HostFuncAVNearerQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, CNum, CDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(-1)); + + ANum = 0; + ADen = 0; + + // Both are at same distance + EXPECT_TRUE( + HostFuncAVNearerQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, CNum, CDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(0)); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_cmp_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVCmpQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = 1; + int32_t ADen = 2; + int32_t BNum = 2; + int32_t BDen = 1; + // A < B + EXPECT_TRUE(HostFuncAVCmpQ.run( + CallFrame, + std::initializer_list{ANum, ADen, BNum, BDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(-1)); + + ANum = 2; + ADen = 1; + BNum = 1; + BDen = 2; + // A > B + EXPECT_TRUE(HostFuncAVCmpQ.run( + CallFrame, + std::initializer_list{ANum, ADen, BNum, BDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(1)); + + ANum = 2; + ADen = 1; + BNum = 2; + BDen = 1; + // A == B + EXPECT_TRUE(HostFuncAVCmpQ.run( + CallFrame, + std::initializer_list{ANum, ADen, BNum, BDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(0)); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_reduce"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVReduce = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int64_t ANum = 1; + int64_t ADen = 2; + int64_t Max = 3; + EXPECT_TRUE( + HostFuncAVReduce.run(CallFrame, + std::initializer_list{ + NumPtr, DenPtr, ANum, ADen, Max}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(1)); + } +} +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp new file mode 100644 index 00000000..d12cd16d --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avutil/module.h" +#include "avutil/samplefmt.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVSampleFmt) { + using namespace std::literals::string_view_literals; + ASSERT_TRUE(AVUtilMod != nullptr); + + uint32_t BufferPtr = UINT32_C(160); + uint32_t NamePtr = UINT32_C(80); + uint32_t LinesizePtr = UINT32_C(20); + + uint32_t SampleFmtId = 1; // AV_SAMPLE_FMT_S32 + auto *FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_packed_sample_fmt"); + auto &HostFuncAVGetPackedSampleFmt = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetPackedSampleFmt &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetPackedSampleFmt.run( + CallFrame, std::initializer_list{SampleFmtId}, + Result); + + EXPECT_EQ(Result[0].get(), SampleFmtId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_planar_sample_fmt"); + auto &HostFuncAVGetPlanarSampleFmt = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetPlanarSampleFmt &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetPlanarSampleFmt.run( + CallFrame, std::initializer_list{SampleFmtId}, + Result); + + EXPECT_EQ(Result[0].get(), 6); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_sample_fmt_is_planar"); + auto &HostFuncAVSampleFmtIsPlanar = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVSampleFmtIsPlanar &>( + FuncInst->getHostFunc()); + + { + HostFuncAVSampleFmtIsPlanar.run( + CallFrame, std::initializer_list{SampleFmtId}, + Result); + + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_bytes_per_sample"); + auto &HostFuncAVGetBytesPerSample = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetBytesPerSample &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetBytesPerSample.run( + CallFrame, std::initializer_list{SampleFmtId}, + Result); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_get_sample_fmt"); + auto &HostFuncAVGetSampleFmt = + dynamic_cast( + FuncInst->getHostFunc()); + + uint32_t SampleFmtStart = 100; + uint32_t SampleFmtSize = 2; + fillMemContent(MemInst, SampleFmtSize, SampleFmtSize); + + fillMemContent(MemInst, SampleFmtStart, "u8"sv); + { + HostFuncAVGetSampleFmt.run(CallFrame, + std::initializer_list{ + SampleFmtStart, SampleFmtSize}, + Result); + + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_samples_get_buffer_size"); + auto &HostFuncAVSamplesGetBufferSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVSamplesGetBufferSize &>( + FuncInst->getHostFunc()); + + int32_t NbChannels = 1; + int32_t NbSamples = 5; + int32_t Align = 1; + int32_t BufSize = 0; + { + HostFuncAVSamplesGetBufferSize.run( + CallFrame, + std::initializer_list{NbChannels, NbSamples, + SampleFmtId, Align}, + Result); + + BufSize = Result[0].get(); + EXPECT_TRUE(BufSize); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_samples_alloc_array_and_samples"); + auto &HostFuncAVSamplesAllocArrayAndSamples = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVSamplesAllocArrayAndSamples &>( + FuncInst->getHostFunc()); + + { + HostFuncAVSamplesAllocArrayAndSamples.run( + CallFrame, + std::initializer_list{ + BufferPtr, LinesizePtr, NbChannels, NbSamples, SampleFmtId, Align}, + Result); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + uint32_t BufId = readUInt32(MemInst, BufferPtr); + ASSERT_TRUE(BufId > 0); + + int32_t Length = 0; + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_sample_fmt_name_length"); + auto &HostFuncAVGetSampleFmtNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetSampleFmtNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetSampleFmtNameLength.run( + CallFrame, std::initializer_list{SampleFmtId}, + Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_sample_fmt_name"); + auto &HostFuncAVGetSampleFmtName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetSampleFmtName &>( + FuncInst->getHostFunc()); + + // Fill Memory with 0. + fillMemContent(MemInst, NamePtr, Length); + { + HostFuncAVGetSampleFmtName.run(CallFrame, + std::initializer_list{ + SampleFmtId, NamePtr, Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_sample_fmt_mask"); + auto &HostFuncAVGetSampleFmtMask = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetSampleFmtMask &>( + FuncInst->getHostFunc()); + + { + uint32_t SampleId = 2; // AV_SAMPLE_FMT_S16; + HostFuncAVGetSampleFmtMask.run( + CallFrame, std::initializer_list{SampleId}, + Result); + + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_freep"); + auto &HostFuncAVFreep = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t BufferId = readUInt32(MemInst, BufferPtr); + HostFuncAVFreep.run(CallFrame, + std::initializer_list{BufferId}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp new file mode 100644 index 00000000..2406cea1 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "avutil/avutil_func.h" +#include "avutil/avTime.h" +#include "avutil/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVUtilFunc) { + ASSERT_TRUE(AVUtilMod != nullptr); + + uint32_t NamePtr = UINT32_C(4); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_log_set_level"); + auto &HostFuncAVLogSetLevel = + dynamic_cast( + FuncInst->getHostFunc()); + + int32_t LogLvlId = 32; + { + HostFuncAVLogSetLevel.run( + CallFrame, std::initializer_list{LogLvlId}, + Result); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_log_get_level"); + auto &HostFuncAVLogGetLevel = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVLogGetLevel.run( + CallFrame, std::initializer_list{}, Result); + EXPECT_EQ(Result[0].get(), LogLvlId); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_log_set_flags"); + auto &HostFuncAVLogSetFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVLogSetFlags.run( + CallFrame, std::initializer_list{1}, Result); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_log_get_flags"); + auto &HostFuncAVLogGetFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVLogGetFlags.run( + CallFrame, std::initializer_list{1}, Result); + + EXPECT_EQ(Result[0].get(), 32); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_rescale_q"); + auto &HostFuncAVRescaleQ = + dynamic_cast( + FuncInst->getHostFunc()); + + int64_t A = 20; + int32_t BNum = 5; + int32_t BDen = 10; + int32_t CNum = 5; + int32_t CDen = 20; + + { + HostFuncAVRescaleQ.run( + CallFrame, + std::initializer_list{A, BNum, BDen, CNum, CDen}, + Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_rescale_q_rnd"); + auto &HostFuncAVRescaleQRnd = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t RoundingId = 2; + HostFuncAVRescaleQRnd.run(CallFrame, + std::initializer_list{ + A, BNum, BDen, CNum, CDen, RoundingId}, + Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_avutil_version"); + auto &HostFuncAVUtilVersion = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilVersion.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + uint64_t ChannelId = 1; // FRONT_LEFT + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_channel_layout_nb_channels"); + auto &HostFuncAVGetChannelLayoutNbChannels = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetChannelLayoutNbChannels &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetChannelLayoutNbChannels.run( + CallFrame, std::initializer_list{ChannelId}, + Result); + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_default_channel_layout"); + auto &HostFuncAVGetDefaultChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetDefaultChannelLayout &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetDefaultChannelLayout.run( + CallFrame, std::initializer_list{ChannelId}, + Result); + EXPECT_TRUE(Result[0].get() > 0); + } + + uint32_t Length = 0; + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_avutil_configuration_length"); + auto &HostFuncAVUtilConfigurationLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVUtilConfigurationLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilConfigurationLength.run( + CallFrame, std::initializer_list{}, Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill NamePtr with 0. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_avutil_configuration"); + auto &HostFuncAVUtilConfiguration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVUtilConfiguration &>( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilConfiguration.run( + CallFrame, std::initializer_list{NamePtr, Length}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_avutil_license_length"); + auto &HostFuncAVUtilLicenseLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVUtilLicenseLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilLicenseLength.run( + CallFrame, std::initializer_list{}, Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill NamePtr with 0. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_avutil_license"); + auto &HostFuncAVUtilLicense = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilLicense.run( + CallFrame, std::initializer_list{NamePtr, Length}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +TEST_F(FFmpegTest, AVTime) { + + ASSERT_TRUE(AVUtilMod != nullptr); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_gettime"); + auto &HostFuncAVGetTime = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVGetTime.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_gettime_relative"); + auto &HostFuncAVGetTimeRelative = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVGetTimeRelative.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_gettime_relative_is_monotonic"); + auto &HostFuncAVGetTimeRelativeIsMonotonic = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetTimeRelativeIsMonotonic &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetTimeRelativeIsMonotonic.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_usleep"); + auto &HostFuncAVUSleep = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUSleep.run( + CallFrame, std::initializer_list{1000}, Result); + + EXPECT_EQ(Result[0].get(), 0); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/main.cpp b/test/plugins/wasmedge_ffmpeg/main.cpp new file mode 100644 index 00000000..c2be683b --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/main.cpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include + +GTEST_API_ int main(int Argc, char **Argv) { + testing::InitGoogleTest(&Argc, Argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp b/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp new file mode 100644 index 00000000..6f84db6f --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "swresample/swresample_func.h" +#include "swresample/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, SWResampleFunc) { + ASSERT_TRUE(SWResampleMod != nullptr); + + uint32_t DictPtr = UINT32_C(4); + uint32_t SWResamplePtr = UINT32_C(8); + uint32_t FramePtr = UINT32_C(72); + uint32_t Frame2Ptr = UINT32_C(16); + uint32_t KeyPtr = UINT32_C(100); + uint32_t ValuePtr = UINT32_C(200); + + initDict(DictPtr, KeyPtr, std::string("Key"), ValuePtr, std::string("Value")); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFFmpegStructs(UINT32_C(20), UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), UINT32_C(64), UINT32_C(68), FramePtr); + + uint32_t StrPtr = UINT32_C(76); + initEmptyFrame(Frame2Ptr); + + uint32_t DictId = readUInt32(MemInst, DictPtr); + uint32_t FrameId = readUInt32(MemInst, FramePtr); + uint32_t Frame2Id = readUInt32(MemInst, Frame2Ptr); + + auto *FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swresample_version"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSWResampleVersion = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWResampleVersion &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSWResampleVersion.run(CallFrame, {}, Result)); + ASSERT_TRUE(Result[0].get() > 0); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swr_alloc_set_opts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrAllocSetOpts = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWRAllocSetOpts &>( + FuncInst->getHostFunc()); + + // Testing with Null Old SwrCtx. Hence 2nd argument is 0. + { + uint32_t SWRCtxId = 0; + uint64_t OutChLayoutId = 1 << 1; // Front Right + uint32_t OutSampleFmtId = 2; // AV_SAMPLE_FMT_S16 + int32_t OutSampleRate = 30; + uint64_t InChLayoutId = 1 << 2; // FRONT_CENTER + uint32_t InSampleFmtId = 3; // AV_SAMPLE_FMT_S32 + int32_t SampleRate = 40; + int32_t LogOffset = 1; + + EXPECT_TRUE(HostFuncSwrAllocSetOpts.run( + CallFrame, + std::initializer_list{ + SWResamplePtr, SWRCtxId, OutChLayoutId, OutSampleFmtId, + OutSampleRate, InChLayoutId, InSampleFmtId, SampleRate, LogOffset}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SWResamplePtr) > 0); + } + + // Test with Existing SwrCtx. + uint32_t SwrId = readUInt32(MemInst, SWResamplePtr); + { + uint32_t SWRCtxId = SwrId; + uint64_t OutChLayoutId = 1 << 1; // Front Right + uint32_t OutSampleFmtId = 2; // AV_SAMPLE_FMT_S16 + int32_t OutSampleRate = 30; + uint64_t InChLayoutId = 1 << 2; // FRONT_CENTER + uint32_t InSampleFmtId = 3; // AV_SAMPLE_FMT_S32 + int32_t SampleRate = 40; + int32_t LogOffset = 1; + EXPECT_TRUE(HostFuncSwrAllocSetOpts.run( + CallFrame, + std::initializer_list{ + SWResamplePtr, SWRCtxId, OutChLayoutId, OutSampleFmtId, + OutSampleRate, InChLayoutId, InSampleFmtId, SampleRate, LogOffset}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SWResamplePtr) > 0); + } + + FuncInst = + SWResampleMod->findFuncExports("wasmedge_ffmpeg_swresample_swr_free"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrFree = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwrFree.run( + CallFrame, std::initializer_list{SwrId}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + SWResampleMod->findFuncExports("wasmedge_ffmpeg_swresample_swr_init"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrInit = + dynamic_cast( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrInit.run( + CallFrame, std::initializer_list{SwrId}, Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_av_opt_set_dict"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVOptSetDict = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t EmptyDictId = 0; + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncAVOptSetDict.run( + CallFrame, + std::initializer_list{SwrId, EmptyDictId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncAVOptSetDict.run( + CallFrame, std::initializer_list{SwrId, DictId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swr_convert_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrConvertFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWRConvertFrame &>( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrConvertFrame.run( + CallFrame, + std::initializer_list{SwrId, Frame2Id, FrameId}, + Result)); + ASSERT_TRUE(Result[0].get()); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swr_get_delay"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrGetDelay = + dynamic_cast( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrGetDelay.run( + CallFrame, std::initializer_list{SwrId, 1}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swresample_configuration_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrConfigLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWResampleConfigurationLength + &>(FuncInst->getHostFunc()); + + int32_t Length = 0; + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrConfigLength.run( + CallFrame, std::initializer_list{}, Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swresample_configuration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrConfig = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWResampleConfiguration &>( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrConfig.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swresample_license_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrLicenseLen = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWResampleLicenseLength &>( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrLicenseLen.run( + CallFrame, std::initializer_list{}, Result)); + + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swresample_license"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrLicense = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWResampleLicense &>( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrLicense.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp b/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp new file mode 100644 index 00000000..805b2224 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp @@ -0,0 +1,541 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "swscale/swscale_func.h" +#include "swscale/module.h" + +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// ============================================================================ +// This test deals with functions related to SwsContext. +// ============================================================================ + +TEST_F(FFmpegTest, SwsContext) { + ASSERT_TRUE(SWScaleMod != nullptr); + + auto *FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getContext"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetContext = + dynamic_cast( + FuncInst->getHostFunc()); + + uint32_t SWScalePtr = UINT32_C(4); + uint32_t SWCachedScalePtr = UINT32_C(8); + uint32_t FramePtr = UINT32_C(72); + uint32_t Frame2Ptr = UINT32_C(124); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFFmpegStructs(UINT32_C(12), UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), UINT32_C(64), UINT32_C(68), FramePtr); + + initEmptyFrame(Frame2Ptr); + + uint32_t FrameId = readUInt32(MemInst, FramePtr); + uint32_t Frame2Id = readUInt32(MemInst, Frame2Ptr); + + uint32_t YUV420PId = 1; // YUV420P AVPixFormatId (From Bindings.h) + uint32_t RGB24Id = 3; // RGB24 AVPixFormatId (From Bindings.h) + uint32_t XVMCId = 174; // XVMC AVPixFormatId (From Bindings.h) + + uint32_t SrcWidth = 100; + uint32_t SrcHeight = 100; + uint32_t DestWidth = 200; + uint32_t DestHeight = 200; + int32_t Flags = 8; + uint32_t SrcFilterId = 0; + uint32_t DestFilterId = 0; + + // Allocating SWScale... + // Filter ID for source and destination is Null. + { + EXPECT_TRUE(HostFuncSwsGetContext.run( + CallFrame, + std::initializer_list{ + SWScalePtr, SrcWidth, SrcHeight, YUV420PId, DestWidth, DestHeight, + RGB24Id, Flags, SrcFilterId, DestFilterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SWScalePtr) > 0); + } + + uint32_t SWSScaleId = readUInt32(MemInst, SWScalePtr); + ASSERT_TRUE(SWSScaleId > 0); + + // Checking correctness of function. Returns Invalid Argument Error. + FuncInst = SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_scale"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsScale = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE( + HostFuncSwsScale.run(CallFrame, + std::initializer_list{ + SWSScaleId, FrameId, 20, 40, Frame2Id}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_getCachedContext"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetCachedContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsGetCachedContext &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwsGetCachedContext.run( + CallFrame, + std::initializer_list{ + SWCachedScalePtr, SWSScaleId, SrcWidth, SrcHeight, YUV420PId, + DestWidth, DestHeight, RGB24Id, Flags, SrcFilterId, DestFilterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SWCachedScalePtr) > 0); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_isSupportedInput"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsIsSupportedInput = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsIsSupportedInput &>( + FuncInst->getHostFunc()); + + { + // AV_PIX_FMT_RGB24 is a supported pixel format. + EXPECT_TRUE(HostFuncSwsIsSupportedInput.run( + CallFrame, std::initializer_list{RGB24Id}, + Result)); + ASSERT_TRUE(Result[0].get() > 0); + + // AV_PIX_FMT_XVMC is not a supported pixel format. + EXPECT_TRUE(HostFuncSwsIsSupportedInput.run( + CallFrame, std::initializer_list{XVMCId}, + Result)); + ASSERT_TRUE(Result[0].get() == 0); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_isSupportedOutput"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsIsSupportedOutput = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsIsSupportedOutput &>( + FuncInst->getHostFunc()); + + { + // AV_PIX_FMT_RGB24 is a supported pixel format. + EXPECT_TRUE(HostFuncSwsIsSupportedOutput.run( + CallFrame, std::initializer_list{RGB24Id}, + Result)); + ASSERT_TRUE(Result[0].get() > 0); + + // AV_PIX_FMT_XVMC is not a supported pixel format. + EXPECT_TRUE(HostFuncSwsIsSupportedOutput.run( + CallFrame, std::initializer_list{XVMCId}, + Result)); + ASSERT_TRUE(Result[0].get() == 0); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_isSupportedEndiannessConversion"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsIsSupportedEndiannessConversion = + dynamic_cast( + FuncInst->getHostFunc()); + + { + // AV_PIX_FMT_XVMC is not a supported pixel format for + EXPECT_TRUE(HostFuncSwsIsSupportedEndiannessConversion.run( + CallFrame, std::initializer_list{XVMCId}, + Result)); + ASSERT_TRUE(Result[0].get() == 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_freeContext"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsFreeContext = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwsFreeContext.run( + CallFrame, std::initializer_list{SWSScaleId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + { + uint32_t InvalidDestWidth = -200; + uint32_t InvalidDestHeight = -200; + uint32_t SWScalePtrInvalid = UINT32_C(80); + EXPECT_TRUE(HostFuncSwsGetContext.run( + CallFrame, + std::initializer_list{ + SWScalePtrInvalid, SrcWidth, SrcHeight, YUV420PId, InvalidDestWidth, + InvalidDestHeight, RGB24Id, Flags, SrcFilterId, DestFilterId}, + Result)); + EXPECT_EQ(Result[0].get(), + static_cast(ErrNo::InternalError)); + ASSERT_TRUE(readUInt32(MemInst, SWScalePtrInvalid) == 0); + } +} + +// ============================================================================ +// This test deals with functions related to SwsFilter. +// ============================================================================ + +TEST_F(FFmpegTest, SwsFilter) { + ASSERT_TRUE(SWScaleMod != nullptr); + auto *FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_getDefaultFilter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetDefaultFilter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsGetDefaultFilter &>( + FuncInst->getHostFunc()); + + uint32_t SwsFilterPtr = UINT32_C(40); + { + float LumaGBlur = 10.5; + float ChromaGBlur = 10.5; + float LumaSharpen = 10.5; + float ChromaSharpen = 10.5; + float ChromaHShift = 10.5; + float ChromaVShift = 10.5; + int32_t Verbose = 1; + + EXPECT_TRUE(HostFuncSwsGetDefaultFilter.run( + CallFrame, + std::initializer_list{ + SwsFilterPtr, LumaGBlur, ChromaGBlur, LumaSharpen, ChromaSharpen, + ChromaHShift, ChromaVShift, Verbose}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsFilterPtr) > 0); + } + + uint32_t FilterId = readUInt32(MemInst, SwsFilterPtr); + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getLumaH"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetLumaH = + dynamic_cast( + FuncInst->getHostFunc()); + + uint32_t SwsVectorPtr = UINT32_C(20); + { + EXPECT_TRUE(HostFuncSwsGetLumaH.run( + CallFrame, + std::initializer_list{FilterId, SwsVectorPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getLumaV"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetLumaV = + dynamic_cast( + FuncInst->getHostFunc()); + + { + writeUInt32(MemInst, UINT32_C(0), SwsVectorPtr); + EXPECT_TRUE(HostFuncSwsGetLumaV.run( + CallFrame, + std::initializer_list{FilterId, SwsVectorPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getChromaH"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetChromaH = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwsGetChromaH.run( + CallFrame, + std::initializer_list{FilterId, SwsVectorPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getChromaV"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetChromaV = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwsGetChromaV.run( + CallFrame, + std::initializer_list{FilterId, SwsVectorPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_freeFilter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsFreeFilter = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwsFreeFilter.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +// ============================================================================ +// This test deals with functions related to SwsVector. +// ============================================================================ + +TEST_F(FFmpegTest, SwsVector) { + ASSERT_TRUE(SWScaleMod != nullptr); + uint32_t SwsVectorPtr = UINT32_C(40); + uint32_t CoeffPtr = UINT32_C(100); + + auto *FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_allocVec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsAllocVec = + dynamic_cast( + FuncInst->getHostFunc()); + + { + writeUInt32(MemInst, UINT32_C(0), SwsVectorPtr); + int32_t Length = 20; + EXPECT_TRUE(HostFuncSwsAllocVec.run( + CallFrame, + std::initializer_list{SwsVectorPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getGaussianVec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetGaussianVec = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsGetGaussianVec &>( + FuncInst->getHostFunc()); + + { + writeUInt32(MemInst, UINT32_C(0), SwsVectorPtr); + double Variance = 20.5; + double Quality = 4.3; + EXPECT_TRUE(HostFuncSwsGetGaussianVec.run( + CallFrame, + std::initializer_list{SwsVectorPtr, Variance, + Quality}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_scaleVec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsScaleVec = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t SwsVecId = readUInt32(MemInst, SwsVectorPtr); + double Scalar = 20.35; + EXPECT_TRUE(HostFuncSwsScaleVec.run( + CallFrame, + std::initializer_list{SwsVecId, Scalar}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_normalizeVec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsNormalizeVec = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t SwsVecId = readUInt32(MemInst, SwsVectorPtr); + double Height = 4.3; + EXPECT_TRUE(HostFuncSwsNormalizeVec.run( + CallFrame, + std::initializer_list{SwsVecId, Height}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_getCoeffVecLength"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetCoeffVecLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsGetCoeffVecLength &>( + FuncInst->getHostFunc()); + + int Length = 0; + { + uint32_t SwsVecId = readUInt32(MemInst, SwsVectorPtr); + EXPECT_TRUE(HostFuncSwsGetCoeffVecLength.run( + CallFrame, std::initializer_list{SwsVecId}, + Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getCoeff"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetCoeff = + dynamic_cast( + FuncInst->getHostFunc()); + + fillMemContent(MemInst, CoeffPtr, Length); + { + uint32_t SwsVecId = readUInt32(MemInst, SwsVectorPtr); + EXPECT_TRUE(HostFuncSwsGetCoeff.run( + CallFrame, + std::initializer_list{SwsVecId, CoeffPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_freeVec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsFreeVec = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t SwsVecId = readUInt32(MemInst, SwsVectorPtr); + EXPECT_TRUE(HostFuncSwsFreeVec.run( + CallFrame, std::initializer_list{SwsVecId}, + Result)); + } +} + +// ============================================================================ +// This test deals with functions related to Version, Configuration, and +// License. +// ============================================================================ + +TEST_F(FFmpegTest, SWScaleVersion) { + ASSERT_TRUE(SWScaleMod != nullptr); + + uint32_t Length = 0; + uint32_t NamePtr = UINT32_C(8); + + auto *FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_swscale_version"); + auto &HostFuncSwscaleVersion = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncSwscaleVersion.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_swscale_configuration_length"); + auto &HostFuncSwscaleConfigurationLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwscaleConfigurationLength &>( + FuncInst->getHostFunc()); + + { + HostFuncSwscaleConfigurationLength.run( + CallFrame, std::initializer_list{}, Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Testing Version, Configuration, License + // Fill NamePtr with 0. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_swscale_configuration"); + auto &HostFuncSwscaleConfiguration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwscaleConfiguration &>( + FuncInst->getHostFunc()); + + { + HostFuncSwscaleConfiguration.run( + CallFrame, std::initializer_list{NamePtr, Length}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_swscale_license_length"); + auto &HostFuncSwscaleLicenseLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwscaleLicenseLength &>( + FuncInst->getHostFunc()); + + { + HostFuncSwscaleLicenseLength.run( + CallFrame, std::initializer_list{}, Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill NamePtr with 0. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_swscale_license"); + auto &HostFuncSwscaleLicense = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncSwscaleLicense.run( + CallFrame, std::initializer_list{NamePtr, Length}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/utils.cpp b/test/plugins/wasmedge_ffmpeg/utils.cpp new file mode 100644 index 00000000..2656cbb3 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/utils.cpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "utils.h" + +#include "avcodec/avCodecContext.h" +#include "avcodec/avPacket.h" +#include "avcodec/avcodec_func.h" +#include "avformat/avStream.h" +#include "avformat/avformat_func.h" +#include "avutil/avDictionary.h" +#include "avutil/avFrame.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +void FFmpegTest::initEmptyFrame(uint32_t FramePtr) { + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_alloc"); + auto &HostFuncAVFrameAlloc = + dynamic_cast( + FuncInst->getHostFunc()); + HostFuncAVFrameAlloc.run( + CallFrame, std::initializer_list{FramePtr}, Result); +} + +void FFmpegTest::initFFmpegStructs(uint32_t AVCodecPtr, uint32_t AVFormatCtxPtr, + uint32_t FilePtr, std::string FileName, + uint32_t CodecParameterPtr, + uint32_t AVCodecCtxPtr, uint32_t PacketPtr, + uint32_t FramePtr) { + initFormatCtx(AVFormatCtxPtr, FilePtr, FileName); + + uint32_t AvFormatCtxId = readUInt32(MemInst, AVFormatCtxPtr); + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_av_find_best_stream"); + auto &HostFuncAVFindBestStream = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFindBestStream &>( + FuncInst->getHostFunc()); + uint32_t MediaTypeId = 0; // Video + uint32_t WantedStream = -1; + uint32_t RelatedStream = -1; + uint32_t DecoderRetId = 0; + uint32_t Flags = 0; + HostFuncAVFindBestStream.run(CallFrame, + std::initializer_list{ + AvFormatCtxId, MediaTypeId, WantedStream, + RelatedStream, DecoderRetId, Flags}, + Result); + + uint32_t StreamIdx = Result[0].get(); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_codecpar"); + + auto &HostFuncAVStreamCodecpar = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamCodecPar &>( + FuncInst->getHostFunc()); + + HostFuncAVStreamCodecpar.run(CallFrame, + std::initializer_list{ + AvFormatCtxId, StreamIdx, CodecParameterPtr}, + Result); + + uint32_t CodecParametersId = readUInt32(MemInst, CodecParameterPtr); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_alloc_context3"); + auto &HostFuncAVCodecAllocContext3 = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecAllocContext3 &>( + FuncInst->getHostFunc()); + + HostFuncAVCodecAllocContext3.run( + CallFrame, std::initializer_list{0, AVCodecCtxPtr}, + Result); + + uint32_t AVCodecCtxId = readUInt32(MemInst, AVCodecCtxPtr); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_to_context"); + auto &HostFuncAVCodecParametersToContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersToContext &>( + FuncInst->getHostFunc()); + + HostFuncAVCodecParametersToContext.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + CodecParametersId}, + Result); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_codec_id"); + auto &HostFuncAVCodecContextCodecId = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxCodecID &>( + FuncInst->getHostFunc()); + + HostFuncAVCodecContextCodecId.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result); + + uint32_t CodecId = Result[0].get(); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_find_decoder"); + auto &HostFuncAVCodecFindDecoder = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFindDecoder &>( + FuncInst->getHostFunc()); + + HostFuncAVCodecFindDecoder.run( + CallFrame, + std::initializer_list{CodecId, AVCodecPtr}, Result); + + uint32_t AVCodecId = readUInt32(MemInst, AVCodecPtr); + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_open2"); + auto &HostFuncAVCodecOpen2 = + dynamic_cast( + FuncInst->getHostFunc()); + + HostFuncAVCodecOpen2.run( + CallFrame, + std::initializer_list{AVCodecCtxId, AVCodecId, 0}, + Result); + + initEmptyFrame(FramePtr); + uint32_t FrameId = readUInt32(MemInst, FramePtr); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_receive_frame"); + auto &HostFuncAVCodecReceiveFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecReceiveFrame &>( + FuncInst->getHostFunc()); + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_read_frame"); + auto &HostFuncAVReadFrame = + dynamic_cast( + FuncInst->getHostFunc()); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_send_packet"); + auto &HostFuncAVCodecSendPacket = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSendPacket &>( + FuncInst->getHostFunc()); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_stream_index"); + auto &HostFuncAVPacketStreamIndex = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketStreamIndex &>( + FuncInst->getHostFunc()); + + while (true) { + HostFuncAVCodecReceiveFrame.run( + CallFrame, + std::initializer_list{AVCodecCtxId, FrameId}, + Result); + + // Error returned by FFmpeg are negative. + int32_t Error = Result[0].get() * (-1); + + if (Error == EAGAIN) { + while (true) { + allocPacket(PacketPtr); + + uint32_t PackedId = readUInt32(MemInst, PacketPtr); + + while (true) { + HostFuncAVReadFrame.run(CallFrame, + std::initializer_list{ + AvFormatCtxId, PackedId}, + Result); + + int32_t Res = Result[0].get(); + if (Res == 0 || Res == AVERROR_EOF) { + break; + } + } + + HostFuncAVPacketStreamIndex.run( + CallFrame, + std::initializer_list{AVCodecCtxId, FrameId}, + Result); + + uint32_t PacketStreamIdx = Result[0].get(); + + if (PacketStreamIdx != StreamIdx) { + continue; + } + + HostFuncAVCodecSendPacket.run( + CallFrame, + std::initializer_list{AVCodecCtxId, PackedId}, + Result); + break; + } + } else { + break; + } + } +} + +void FFmpegTest::initFormatCtx(uint32_t AVFormatCtxPtr, uint32_t FilePtr, + std::string FileName) { + int32_t Length = FileName.length(); + fillMemContent(MemInst, FilePtr, Length); + fillMemContent(MemInst, FilePtr, FileName); + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_open_input"); + auto &HostFuncAVFormatOpenInput = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatOpenInput &>( + FuncInst->getHostFunc()); + HostFuncAVFormatOpenInput.run( + CallFrame, + std::initializer_list{ + AVFormatCtxPtr, FilePtr, Length, UINT32_C(0), UINT32_C(0)}, + Result); +} + +void FFmpegTest::initDict(uint32_t DictPtr, uint32_t KeyPtr, std::string Key, + uint32_t ValuePtr, std::string Value) { + uint32_t KeyLen = Key.length(); + uint32_t ValueLen = Value.length(); + fillMemContent(MemInst, KeyPtr, KeyLen + ValueLen); + fillMemContent(MemInst, KeyPtr, Key); + fillMemContent(MemInst, ValuePtr, Value); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_dict_set"); + auto &HostFuncAVDictSet = + dynamic_cast( + FuncInst->getHostFunc()); + + HostFuncAVDictSet.run(CallFrame, + std::initializer_list{ + DictPtr, KeyPtr, KeyLen, ValuePtr, ValueLen, 0}, + Result); +} + +void FFmpegTest::allocPacket(uint32_t PacketPtr) { + auto *FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_alloc"); + auto &HostFuncAVPacketAlloc = + dynamic_cast( + FuncInst->getHostFunc()); + + HostFuncAVPacketAlloc.run( + CallFrame, std::initializer_list{PacketPtr}, + Result); +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/utils.h b/test/plugins/wasmedge_ffmpeg/utils.h new file mode 100644 index 00000000..79b00242 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/utils.h @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "avcodec/module.h" +#include "avfilter/module.h" +#include "avformat/module.h" +#include "avutil/module.h" +#include "swresample/module.h" +#include "swscale/module.h" + +#include "common/types.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +inline void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + uint32_t Value, uint32_t &Ptr) { + uint32_t *BufPtr = MemInst->getPointer(Ptr); + *BufPtr = Value; +} + +inline void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + uint32_t Offset, uint32_t Cnt, + uint8_t C = 0) noexcept { + std::fill_n(MemInst->getPointer(Offset), Cnt, C); +} + +inline void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + uint32_t Offset, std::string_view Str) noexcept { + char *Buf = MemInst->getPointer(Offset); + std::copy_n(Str.data(), Str.length(), Buf); +} + +inline void writeSInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + int32_t Value, uint32_t &Ptr) { + int32_t *BufPtr = MemInst->getPointer(Ptr); + *BufPtr = Value; +} + +inline int32_t readSInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + uint32_t &Ptr) { + int32_t *BufPtr = MemInst->getPointer(Ptr); + return *BufPtr; +} + +inline uint32_t readUInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + uint32_t &Ptr) { + uint32_t *BufPtr = MemInst->getPointer(Ptr); + return *BufPtr; +} + +class FFmpegTest : public ::testing::Test { +public: + FFmpegTest() : Mod(""), CallFrame(nullptr, &Mod) { + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + MemInst = Mod.findMemoryExports("memory"); + + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_ffmpeg/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeFFmpeg" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_ffmpeg"sv)) { + if (const auto *Module = + Plugin->findModule("wasmedge_ffmpeg_avformat"sv)) { + AVFormatMod = + dynamicPointerCast( + Module->create()); + } + if (const auto *Module = Plugin->findModule("wasmedge_ffmpeg_avutil"sv)) { + AVUtilMod = dynamicPointerCast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::WasmEdgeFFmpegAVUtilModule>( + Module->create()); + } + if (const auto *Module = + Plugin->findModule("wasmedge_ffmpeg_swscale"sv)) { + SWScaleMod = + dynamicPointerCast( + Module->create()); + } + if (const auto *Module = + Plugin->findModule("wasmedge_ffmpeg_avcodec"sv)) { + AVCodecMod = + dynamicPointerCast( + Module->create()); + } + if (const auto *Module = + Plugin->findModule("wasmedge_ffmpeg_swresample"sv)) { + SWResampleMod = + dynamicPointerCast( + Module->create()); + } + if (const auto *Module = + Plugin->findModule("wasmedge_ffmpeg_avfilter"sv)) { + AVFilterMod = + dynamicPointerCast( + Module->create()); + } + } + } + +protected: + void initEmptyFrame(uint32_t FramePtr); + + void initDict(uint32_t DictPtr, uint32_t KeyPtr, std::string Key, + uint32_t ValuePtr, std::string Value); + void initFFmpegStructs(uint32_t AVCodecPtr, uint32_t AVFormatCtxPtr, + uint32_t FilePtr, std::string FileName, + uint32_t CodecParameterPtr, uint32_t AVCodecCtxPtr, + uint32_t PacketPtr, uint32_t FramePtr); + + void initFormatCtx(uint32_t AVFormatCtxPtr, uint32_t FilePtr, + std::string FileName); + void allocPacket(uint32_t PacketPtr); + + // Results of Funcs are stored here. + std::array Result = {UINT32_C(0)}; + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod; + WasmEdge::Runtime::Instance::MemoryInstance *MemInst; + WasmEdge::Runtime::CallingFrame CallFrame; + + // Wasm Modules. + std::unique_ptr< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::WasmEdgeFFmpegAVFormatModule> + AVFormatMod; + std::unique_ptr< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::WasmEdgeFFmpegAVUtilModule> + AVUtilMod; + std::unique_ptr + SWResampleMod; + std::unique_ptr< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::WasmEdgeFFmpegSWScaleModule> + SWScaleMod; + std::unique_ptr< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::WasmEdgeFFmpegAVCodecModule> + AVCodecMod; + std::unique_ptr< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::WasmEdgeFFmpegAVFilterModule> + AVFilterMod; +}; + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_image/CMakeLists.txt b/test/plugins/wasmedge_image/CMakeLists.txt new file mode 100644 index 00000000..62e1bd25 --- /dev/null +++ b/test/plugins/wasmedge_image/CMakeLists.txt @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasmedgeImageTests + wasmedge_image.cpp +) + +add_dependencies(wasmedgeImageTests + wasmedgePluginWasmEdgeImage +) + +target_include_directories(wasmedgeImageTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeImageTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeImageTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeImageTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeImageTests wasmedgeImageTests) diff --git a/test/plugins/wasmedge_image/wasmedge_image.cpp b/test/plugins/wasmedge_image/wasmedge_image.cpp new file mode 100644 index 00000000..479afbd0 --- /dev/null +++ b/test/plugins/wasmedge_image/wasmedge_image.cpp @@ -0,0 +1,517 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/defines.h" +#include "image_module.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" + +#include +#include +#include +#include +#include +#include +#include + +using WasmEdge::Host::WasmEdgeImage::ErrNo; + +namespace { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_image/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeImage" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_image"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_image"sv)) { + return dynamicPointerCast( + Module->create()); + } + } + return {}; +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, uint32_t Cnt, uint8_t C = 0) noexcept { + std::fill_n(MemInst.getPointer(Offset), Cnt, C); +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, + const std::vector &Payload) noexcept { + uint8_t *Buf = MemInst.getPointer(Offset); + std::copy_n(Payload.data(), Payload.size(), Buf); +} + +// Test image: #FF0000 png file, 30x30 image size, 158 bytes. +std::vector TestRedPNG = { + 0x89U, 0x50U, 0x4EU, 0x47U, 0x0DU, 0x0AU, 0x1AU, 0x0AU, 0x00U, 0x00U, 0x00U, + 0x0DU, 0x49U, 0x48U, 0x44U, 0x52U, 0x00U, 0x00U, 0x00U, 0x1EU, 0x00U, 0x00U, + 0x00U, 0x1EU, 0x08U, 0x06U, 0x00U, 0x00U, 0x00U, 0x3BU, 0x30U, 0xAEU, 0xA2U, + 0x00U, 0x00U, 0x00U, 0x01U, 0x73U, 0x52U, 0x47U, 0x42U, 0x00U, 0xAEU, 0xCEU, + 0x1CU, 0xE9U, 0x00U, 0x00U, 0x00U, 0x04U, 0x67U, 0x41U, 0x4DU, 0x41U, 0x00U, + 0x00U, 0xB1U, 0x8FU, 0x0BU, 0xFCU, 0x61U, 0x05U, 0x00U, 0x00U, 0x00U, 0x09U, + 0x70U, 0x48U, 0x59U, 0x73U, 0x00U, 0x00U, 0x16U, 0x25U, 0x00U, 0x00U, 0x16U, + 0x25U, 0x01U, 0x49U, 0x52U, 0x24U, 0xF0U, 0x00U, 0x00U, 0x00U, 0x33U, 0x49U, + 0x44U, 0x41U, 0x54U, 0x48U, 0x4BU, 0xEDU, 0xCDU, 0xA1U, 0x01U, 0x00U, 0x00U, + 0x0CU, 0x83U, 0xB0U, 0xFEU, 0xFFU, 0x74U, 0xE7U, 0x77U, 0x00U, 0x35U, 0x88U, + 0x18U, 0x0CU, 0x69U, 0xD2U, 0x85U, 0xFCU, 0x40U, 0x71U, 0x8CU, 0x71U, 0x8CU, + 0x71U, 0x8CU, 0x71U, 0x8CU, 0x71U, 0x8CU, 0x71U, 0x8CU, 0x71U, 0x8CU, 0x71U, + 0x8CU, 0x71U, 0x8CU, 0x99U, 0x8DU, 0x0FU, 0xD5U, 0x6CU, 0x01U, 0x62U, 0x5DU, + 0xE8U, 0xB5U, 0x3DU, 0x00U, 0x00U, 0x00U, 0x00U, 0x49U, 0x45U, 0x4EU, 0x44U, + 0xAEU, 0x42U, 0x60U, 0x82U}; + +// Test image: #FF0000 jpg file, 30x30 image size, 647 bytes. +std::vector TestRedJPG = { + 0xFFU, 0xD8U, 0xFFU, 0xE0U, 0x00U, 0x10U, 0x4AU, 0x46U, 0x49U, 0x46U, 0x00U, + 0x01U, 0x01U, 0x01U, 0x00U, 0x90U, 0x00U, 0x90U, 0x00U, 0x00U, 0xFFU, 0xDBU, + 0x00U, 0x43U, 0x00U, 0x02U, 0x01U, 0x01U, 0x02U, 0x01U, 0x01U, 0x02U, 0x02U, + 0x02U, 0x02U, 0x02U, 0x02U, 0x02U, 0x02U, 0x03U, 0x05U, 0x03U, 0x03U, 0x03U, + 0x03U, 0x03U, 0x06U, 0x04U, 0x04U, 0x03U, 0x05U, 0x07U, 0x06U, 0x07U, 0x07U, + 0x07U, 0x06U, 0x07U, 0x07U, 0x08U, 0x09U, 0x0BU, 0x09U, 0x08U, 0x08U, 0x0AU, + 0x08U, 0x07U, 0x07U, 0x0AU, 0x0DU, 0x0AU, 0x0AU, 0x0BU, 0x0CU, 0x0CU, 0x0CU, + 0x0CU, 0x07U, 0x09U, 0x0EU, 0x0FU, 0x0DU, 0x0CU, 0x0EU, 0x0BU, 0x0CU, 0x0CU, + 0x0CU, 0xFFU, 0xDBU, 0x00U, 0x43U, 0x01U, 0x02U, 0x02U, 0x02U, 0x03U, 0x03U, + 0x03U, 0x06U, 0x03U, 0x03U, 0x06U, 0x0CU, 0x08U, 0x07U, 0x08U, 0x0CU, 0x0CU, + 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, + 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, + 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, + 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, + 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0xFFU, 0xC0U, 0x00U, 0x11U, 0x08U, 0x00U, 0x1EU, + 0x00U, 0x1EU, 0x03U, 0x01U, 0x22U, 0x00U, 0x02U, 0x11U, 0x01U, 0x03U, 0x11U, + 0x01U, 0xFFU, 0xC4U, 0x00U, 0x1FU, 0x00U, 0x00U, 0x01U, 0x05U, 0x01U, 0x01U, + 0x01U, 0x01U, 0x01U, 0x01U, 0x00U, 0x00U, 0x00U, 0x00U, 0x00U, 0x00U, 0x00U, + 0x00U, 0x01U, 0x02U, 0x03U, 0x04U, 0x05U, 0x06U, 0x07U, 0x08U, 0x09U, 0x0AU, + 0x0BU, 0xFFU, 0xC4U, 0x00U, 0xB5U, 0x10U, 0x00U, 0x02U, 0x01U, 0x03U, 0x03U, + 0x02U, 0x04U, 0x03U, 0x05U, 0x05U, 0x04U, 0x04U, 0x00U, 0x00U, 0x01U, 0x7DU, + 0x01U, 0x02U, 0x03U, 0x00U, 0x04U, 0x11U, 0x05U, 0x12U, 0x21U, 0x31U, 0x41U, + 0x06U, 0x13U, 0x51U, 0x61U, 0x07U, 0x22U, 0x71U, 0x14U, 0x32U, 0x81U, 0x91U, + 0xA1U, 0x08U, 0x23U, 0x42U, 0xB1U, 0xC1U, 0x15U, 0x52U, 0xD1U, 0xF0U, 0x24U, + 0x33U, 0x62U, 0x72U, 0x82U, 0x09U, 0x0AU, 0x16U, 0x17U, 0x18U, 0x19U, 0x1AU, + 0x25U, 0x26U, 0x27U, 0x28U, 0x29U, 0x2AU, 0x34U, 0x35U, 0x36U, 0x37U, 0x38U, + 0x39U, 0x3AU, 0x43U, 0x44U, 0x45U, 0x46U, 0x47U, 0x48U, 0x49U, 0x4AU, 0x53U, + 0x54U, 0x55U, 0x56U, 0x57U, 0x58U, 0x59U, 0x5AU, 0x63U, 0x64U, 0x65U, 0x66U, + 0x67U, 0x68U, 0x69U, 0x6AU, 0x73U, 0x74U, 0x75U, 0x76U, 0x77U, 0x78U, 0x79U, + 0x7AU, 0x83U, 0x84U, 0x85U, 0x86U, 0x87U, 0x88U, 0x89U, 0x8AU, 0x92U, 0x93U, + 0x94U, 0x95U, 0x96U, 0x97U, 0x98U, 0x99U, 0x9AU, 0xA2U, 0xA3U, 0xA4U, 0xA5U, + 0xA6U, 0xA7U, 0xA8U, 0xA9U, 0xAAU, 0xB2U, 0xB3U, 0xB4U, 0xB5U, 0xB6U, 0xB7U, + 0xB8U, 0xB9U, 0xBAU, 0xC2U, 0xC3U, 0xC4U, 0xC5U, 0xC6U, 0xC7U, 0xC8U, 0xC9U, + 0xCAU, 0xD2U, 0xD3U, 0xD4U, 0xD5U, 0xD6U, 0xD7U, 0xD8U, 0xD9U, 0xDAU, 0xE1U, + 0xE2U, 0xE3U, 0xE4U, 0xE5U, 0xE6U, 0xE7U, 0xE8U, 0xE9U, 0xEAU, 0xF1U, 0xF2U, + 0xF3U, 0xF4U, 0xF5U, 0xF6U, 0xF7U, 0xF8U, 0xF9U, 0xFAU, 0xFFU, 0xC4U, 0x00U, + 0x1FU, 0x01U, 0x00U, 0x03U, 0x01U, 0x01U, 0x01U, 0x01U, 0x01U, 0x01U, 0x01U, + 0x01U, 0x01U, 0x00U, 0x00U, 0x00U, 0x00U, 0x00U, 0x00U, 0x01U, 0x02U, 0x03U, + 0x04U, 0x05U, 0x06U, 0x07U, 0x08U, 0x09U, 0x0AU, 0x0BU, 0xFFU, 0xC4U, 0x00U, + 0xB5U, 0x11U, 0x00U, 0x02U, 0x01U, 0x02U, 0x04U, 0x04U, 0x03U, 0x04U, 0x07U, + 0x05U, 0x04U, 0x04U, 0x00U, 0x01U, 0x02U, 0x77U, 0x00U, 0x01U, 0x02U, 0x03U, + 0x11U, 0x04U, 0x05U, 0x21U, 0x31U, 0x06U, 0x12U, 0x41U, 0x51U, 0x07U, 0x61U, + 0x71U, 0x13U, 0x22U, 0x32U, 0x81U, 0x08U, 0x14U, 0x42U, 0x91U, 0xA1U, 0xB1U, + 0xC1U, 0x09U, 0x23U, 0x33U, 0x52U, 0xF0U, 0x15U, 0x62U, 0x72U, 0xD1U, 0x0AU, + 0x16U, 0x24U, 0x34U, 0xE1U, 0x25U, 0xF1U, 0x17U, 0x18U, 0x19U, 0x1AU, 0x26U, + 0x27U, 0x28U, 0x29U, 0x2AU, 0x35U, 0x36U, 0x37U, 0x38U, 0x39U, 0x3AU, 0x43U, + 0x44U, 0x45U, 0x46U, 0x47U, 0x48U, 0x49U, 0x4AU, 0x53U, 0x54U, 0x55U, 0x56U, + 0x57U, 0x58U, 0x59U, 0x5AU, 0x63U, 0x64U, 0x65U, 0x66U, 0x67U, 0x68U, 0x69U, + 0x6AU, 0x73U, 0x74U, 0x75U, 0x76U, 0x77U, 0x78U, 0x79U, 0x7AU, 0x82U, 0x83U, + 0x84U, 0x85U, 0x86U, 0x87U, 0x88U, 0x89U, 0x8AU, 0x92U, 0x93U, 0x94U, 0x95U, + 0x96U, 0x97U, 0x98U, 0x99U, 0x9AU, 0xA2U, 0xA3U, 0xA4U, 0xA5U, 0xA6U, 0xA7U, + 0xA8U, 0xA9U, 0xAAU, 0xB2U, 0xB3U, 0xB4U, 0xB5U, 0xB6U, 0xB7U, 0xB8U, 0xB9U, + 0xBAU, 0xC2U, 0xC3U, 0xC4U, 0xC5U, 0xC6U, 0xC7U, 0xC8U, 0xC9U, 0xCAU, 0xD2U, + 0xD3U, 0xD4U, 0xD5U, 0xD6U, 0xD7U, 0xD8U, 0xD9U, 0xDAU, 0xE2U, 0xE3U, 0xE4U, + 0xE5U, 0xE6U, 0xE7U, 0xE8U, 0xE9U, 0xEAU, 0xF2U, 0xF3U, 0xF4U, 0xF5U, 0xF6U, + 0xF7U, 0xF8U, 0xF9U, 0xFAU, 0xFFU, 0xDAU, 0x00U, 0x0CU, 0x03U, 0x01U, 0x00U, + 0x02U, 0x11U, 0x03U, 0x11U, 0x00U, 0x3FU, 0x00U, 0xF8U, 0xBEU, 0x8AU, 0x28U, + 0xAFU, 0xE5U, 0x33U, 0xFDU, 0xFCU, 0x0AU, 0x28U, 0xA2U, 0x80U, 0x0AU, 0x28U, + 0xA2U, 0x80U, 0x0AU, 0x28U, 0xA2U, 0x80U, 0x3FU, 0xFFU, 0xD9U}; + +} // namespace + +TEST(WasmEdgeImageTest, Module) { + // Create the wasmedge_image module instance. + auto ImgMod = createModule(); + ASSERT_TRUE(ImgMod); + EXPECT_EQ(ImgMod->getFuncExportNum(), 3U); + EXPECT_NE(ImgMod->findFuncExports("load_jpg"), nullptr); + EXPECT_NE(ImgMod->findFuncExports("load_png"), nullptr); + EXPECT_NE(ImgMod->findFuncExports("load_image"), nullptr); +} + +TEST(WasmEdgeImageTest, LoadJPG) { + // Create the wasmedge_image module instance. + auto ImgMod = createModule(); + ASSERT_TRUE(ImgMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + std::array Errno = {UINT32_C(0)}; + + // Set the target size. + uint32_t TargetW = 50, TargetH = 60; + uint32_t TargetSize = TargetW * TargetH * 3; + // Assume the pixel position (45, 55). + uint32_t Position = 55 * TargetW + 45; + // Input payload offset. + uint32_t InOffset = 0; + // Output image data offset. + uint32_t OutOffset = 1024; + // Output image span. + WasmEdge::Span OutSpanU8; + WasmEdge::Span OutSpanF32; + + // Get the function "load_jpg". + auto *FuncInst = ImgMod->findFuncExports("load_jpg"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = FuncInst->getHostFunc(); + + // Test: Load JPG and resize into 50x60 RGB u8 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 647] as the JPG image payload. + fillMemContent(MemInst, 0, TestRedJPG); + // Get output image span. + OutSpanU8 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedJPG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 0U, // Target type: RGB8. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(uint8_t)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Note: Due to the JPG compression, the R is 254, not 255 here. + EXPECT_EQ(OutSpanU8[Position * 3], UINT8_C(254)); + EXPECT_EQ(OutSpanU8[Position * 3 + 1], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 2], UINT8_C(0)); + + // Test: Load JPG and resize into 50x60 BGR u8 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 158] as the JPG image payload. + fillMemContent(MemInst, 0, TestRedJPG); + // Get output image span. + OutSpanU8 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedJPG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 1U, // Target type: BGR8. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(uint8_t)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Note: Due to the JPG compression, the R is 254, not 255 here. + EXPECT_EQ(OutSpanU8[Position * 3], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 1], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 2], UINT8_C(254)); + + // Test: Load JPG and resize into 50x60 RGB f32 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 647] as the JPG image payload. + fillMemContent(MemInst, 0, TestRedJPG); + // Get output image span. + OutSpanF32 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedJPG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 2U, // Target type: RGB32F. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(float)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Note: Due to the JPG compression, the R is 0.991392851f, not 1.0f here. + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3] - 1.0f) < 0.01f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 1] - 0.0f) < 0.01f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 2] - 0.0f) < 0.01f); + + // Test: Load JPG and resize into 50x60 BGR f32 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 647] as the JPG image payload. + fillMemContent(MemInst, 0, TestRedJPG); + // Get output image span. + OutSpanF32 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedJPG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 3U, // Target type: BGR32F. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(float)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Note: Due to the JPG compression, the R is 0.991392851f, not 1.0f here. + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3] - 0.0f) < 0.01f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 1] - 0.0f) < 0.01f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 2] - 1.0f) < 0.01f); +} + +TEST(WasmEdgeImageTest, LoadPNG) { + // Create the wasmedge_image module instance. + auto ImgMod = createModule(); + ASSERT_TRUE(ImgMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + std::array Errno = {UINT32_C(0)}; + + // Set the target size. + uint32_t TargetW = 50, TargetH = 60; + uint32_t TargetSize = TargetW * TargetH * 3; + // Assume the pixel position (45, 55). + uint32_t Position = 55 * TargetW + 45; + // Input payload offset. + uint32_t InOffset = 0; + // Output image data offset. + uint32_t OutOffset = 1024; + // Output image span. + WasmEdge::Span OutSpanU8; + WasmEdge::Span OutSpanF32; + + // Get the function "load_png". + auto *FuncInst = ImgMod->findFuncExports("load_png"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = FuncInst->getHostFunc(); + + // Test: Load PNG and resize into 50x60 RGB u8 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 158] as the PNG image payload. + fillMemContent(MemInst, 0, TestRedPNG); + // Get output image span. + OutSpanU8 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedPNG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 0U, // Target type: RGB8. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(uint8_t)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(OutSpanU8[Position * 3], UINT8_C(255)); + EXPECT_EQ(OutSpanU8[Position * 3 + 1], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 2], UINT8_C(0)); + + // Test: Load PNG and resize into 50x60 BGR u8 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 158] as the PNG image payload. + fillMemContent(MemInst, 0, TestRedPNG); + // Get output image span. + OutSpanU8 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedPNG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 1U, // Target type: BGR8. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(uint8_t)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(OutSpanU8[Position * 3], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 1], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 2], UINT8_C(255)); + + // Test: Load PNG and resize into 50x60 RGB f32 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 158] as the PNG image payload. + fillMemContent(MemInst, 0, TestRedPNG); + // Get output image span. + OutSpanF32 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedPNG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 2U, // Target type: RGB32F. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(float)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3] - 1.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 1] - 0.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 2] - 0.0f) < 0.00001f); + + // Test: Load PNG and resize into 50x60 BGR f32 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 158] as the PNG image payload. + fillMemContent(MemInst, 0, TestRedPNG); + // Get output image span. + OutSpanF32 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedPNG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 3U, // Target type: BGR32F. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(float)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3] - 0.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 1] - 0.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 2] - 1.0f) < 0.00001f); +} + +TEST(WasmEdgeImageTest, LoadImage) { + // Test for the general API. + + // Create the wasmedge_image module instance. + auto ImgMod = createModule(); + ASSERT_TRUE(ImgMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + std::array Errno = {UINT32_C(0)}; + + // Set the target size. + uint32_t TargetW = 50, TargetH = 60; + uint32_t TargetSize = TargetW * TargetH * 3; + // Assume the pixel position (45, 55). + uint32_t Position = 55 * TargetW + 45; + // Input payload offset. + uint32_t InOffset = 0; + // Output image data offset. + uint32_t OutOffset = 1024; + // Output image span. + WasmEdge::Span OutSpanU8; + WasmEdge::Span OutSpanF32; + + // Get the function "load_image". + auto *FuncInst = ImgMod->findFuncExports("load_image"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = FuncInst->getHostFunc(); + + // Test: Load JPG and resize into 50x60 BGR u8 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 647] as the JPG image payload. + fillMemContent(MemInst, 0, TestRedJPG); + // Get output image span. + OutSpanU8 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedJPG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 1U, // Target type: BGR8. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(uint8_t)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Note: Due to the JPG compression, the R is 254, not 255 here. + EXPECT_EQ(OutSpanU8[Position * 3], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 1], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 2], UINT8_C(254)); + + // Test: Load PNG and resize into 50x60 RGB f32 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 647] as the PNG image payload. + fillMemContent(MemInst, 0, TestRedPNG); + // Get output image span. + OutSpanF32 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedPNG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 2U, // Target type: RGB32F. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(float)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3] - 1.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 1] - 0.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 2] - 0.0f) < 0.00001f); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_llmc/CMakeLists.txt b/test/plugins/wasmedge_llmc/CMakeLists.txt new file mode 100644 index 00000000..7fce5e8a --- /dev/null +++ b/test/plugins/wasmedge_llmc/CMakeLists.txt @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasmedgeLLMCTests + wasmedge_llmc.cpp +) + +add_dependencies(wasmedgeLLMCTests + wasmedgePluginWasmEdgeLLMC +) + +target_include_directories(wasmedgeLLMCTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeLLMCTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) + +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeLLMCTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeLLMCTests + PRIVATE + wasmedge_shared + ) +endif() + +function(download URL OUTPUT HASH) + file(DOWNLOAD + ${URL} + ${OUTPUT} + SHOW_PROGRESS + EXPECTED_HASH ${HASH} + ) +endfunction() + +message(STATUS "Downloading GPT2 model check point to ${CMAKE_CURRENT_BINARY_DIR}/gpt2_124M.bin") +if (WASMEDGE_PLUGIN_LLMC_CUDA) + download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M_bf16.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/gpt2_124M.bin + SHA256=6661f45628102b4c6e86835d9057b5ba2c024dbf9b81445175e258b7878a1a6f + ) +else() + download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/gpt2_124M.bin + SHA256=3da8b207584030bcdcd207cf7a99952e3421dce92da218b351071857511bf162 + ) +endif() +message(STATUS "Downloading training dataset to ${CMAKE_CURRENT_BINARY_DIR}/tiny_shakespeare_train.bin") +download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_train.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/tiny_shakespeare_train.bin + SHA256=8a70606be574040c26d225694f5f9759973b419852d22f7fe5c118e1b359dcc8 +) +message(STATUS "Downloading validation dataset to ${CMAKE_CURRENT_BINARY_DIR}/tiny_shakespeare_val.bin") +download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/tiny_shakespeare_val.bin + SHA256=fe99db720dc7c83e694806d4e047a952909411da1daccde4ccc2e55f40882a62 +) +message(STATUS "Downloading tokenizer data to ${CMAKE_CURRENT_BINARY_DIR}/gpt2_tokenizer.bin") +download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_tokenizer.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/gpt2_tokenizer.bin + SHA256=6f3abc21e444e4e8300e225f4e03da48ea121cf17e30f67009b8dad7a66c2f13 +) + +add_test(wasmedgeLLMCTests wasmedgeLLMCTests) diff --git a/test/plugins/wasmedge_llmc/wasmedge_llmc.cpp b/test/plugins/wasmedge_llmc/wasmedge_llmc.cpp new file mode 100644 index 00000000..4fcbff20 --- /dev/null +++ b/test/plugins/wasmedge_llmc/wasmedge_llmc.cpp @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "llmc_func.h" +#include "llmc_module.h" + +#include "common/defines.h" +#include "common/types.h" +#include "plugin/plugin.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +using WasmEdge::Host::WasmEdgeLLMC::ErrNo; + +namespace { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_llmc/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeLLMC" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_llmc"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_llmc"sv)) { + return dynamicPointerCast( + Module->create()); + } + } + return {}; +} +} // namespace + +template +void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + WasmEdge::Span Binaries, uint32_t Ptr) noexcept { + std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); +} + +void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Value, uint32_t &Ptr) { + uint32_t *BufPtr = MemInst.getPointer(Ptr); + *BufPtr = Value; + Ptr += 4; +} + +TEST(WasmEdgeLLMTest, TrainGPT2) { + // Create wasmedge_llmc module instance. + auto LLMCMod = createModule(); + ASSERT_TRUE(LLMCMod); + EXPECT_EQ(LLMCMod->getFuncExportNum(), 4U); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + EXPECT_NE(MemInstPtr, nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + auto *ModelCreate = LLMCMod->findFuncExports("model_create"); + EXPECT_NE(ModelCreate, nullptr); + EXPECT_TRUE(ModelCreate->isHostFunction()); + auto &HostFuncModelCreate = + dynamic_cast( + ModelCreate->getHostFunc()); + + auto *DataLoaderCreate = LLMCMod->findFuncExports("dataloader_create"); + EXPECT_NE(DataLoaderCreate, nullptr); + EXPECT_TRUE(DataLoaderCreate->isHostFunction()); + auto &HostFuncDataLoadereCreate = + dynamic_cast( + DataLoaderCreate->getHostFunc()); + + auto *TokenizerCreate = LLMCMod->findFuncExports("tokenizer_create"); + EXPECT_NE(TokenizerCreate, nullptr); + EXPECT_TRUE(TokenizerCreate->isHostFunction()); + auto &HostFuncTokenizerCreate = + dynamic_cast( + TokenizerCreate->getHostFunc()); + + auto *ModelTrain = LLMCMod->findFuncExports("model_train"); + EXPECT_NE(ModelTrain, nullptr); + EXPECT_TRUE(ModelTrain->isHostFunction()); + auto &HostFuncModelTrain = + dynamic_cast( + ModelTrain->getHostFunc()); + + std::array Errno = {UINT32_C(0)}; + + std::string CheckPointString = "./wasmedge_llmc/gpt2_124M.bin"; + std::vector CheckPointPath(CheckPointString.begin(), + CheckPointString.end()); + uint32_t CheckPointPathPtr = UINT32_C(0); + writeBinaries(MemInst, CheckPointPath, CheckPointPathPtr); + + std::string TrainDataString = "./wasmedge_llmc/tiny_shakespeare_train.bin"; + std::vector TrainDataPath(TrainDataString.begin(), + TrainDataString.end()); + uint32_t TrainDataPathPtr = CheckPointPathPtr + CheckPointPath.size(); + writeBinaries(MemInst, TrainDataPath, TrainDataPathPtr); + + std::string ValDataString = "./wasmedge_llmc/tiny_shakespeare_val.bin"; + std::vector ValDataPath(ValDataString.begin(), ValDataString.end()); + uint32_t ValDataPathPtr = TrainDataPathPtr + TrainDataPath.size(); + writeBinaries(MemInst, ValDataPath, ValDataPathPtr); + + std::string TokenizerBin = "./wasmedge_llmc/gpt2_tokenizer.bin"; + std::vector TokenizerBinPath(TokenizerBin.begin(), TokenizerBin.end()); + uint32_t TokenizerBinPtr = ValDataPathPtr + ValDataPath.size(); + writeBinaries(MemInst, TokenizerBinPath, TokenizerBinPtr); + + uint32_t ModelIdPtr = UINT32_C(0); + uint32_t ModelId = UINT32_C(0); + uint32_t TrainDataLoaderIdPtr = UINT32_C(0); + uint32_t TrainDataLoaderId = UINT32_C(0); + uint32_t ValDataLoaderIdPtr = UINT32_C(0); + uint32_t ValDataLoaderId = UINT32_C(0); + uint32_t TokenizerIdPtr = UINT32_C(0); + uint32_t TokenizerId = UINT32_C(0); + + { + EXPECT_TRUE(HostFuncModelCreate.run( + CallFrame, + std::initializer_list{ + CheckPointPathPtr, static_cast(CheckPointPath.size()), + ModelIdPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + ModelId = *MemInst.getPointer(ModelIdPtr); + EXPECT_EQ(ModelId, 0); + } + + { + EXPECT_TRUE(HostFuncDataLoadereCreate.run( + CallFrame, + std::initializer_list{ + TrainDataPathPtr, static_cast(TrainDataPath.size()), + /*B*/ 4, + /*T*/ 64, + /*ProcessRank*/ 0, + /*NumProcesses*/ 1, + /*ShouldShuffle*/ 1, TrainDataLoaderIdPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + TrainDataLoaderId = *MemInst.getPointer(TrainDataLoaderIdPtr); + EXPECT_EQ(TrainDataLoaderId, 0); + } + + { + EXPECT_TRUE(HostFuncDataLoadereCreate.run( + CallFrame, + std::initializer_list{ + ValDataPathPtr, static_cast(ValDataPath.size()), + /*B*/ 4, + /*T*/ 64, + /*ProcessRank*/ 0, + /*NumProcesses*/ 1, + /*ShouldShuffle*/ 0, ValDataLoaderIdPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + ValDataLoaderId = *MemInst.getPointer(ValDataLoaderIdPtr); + EXPECT_EQ(ValDataLoaderId, 1); + } + + { + EXPECT_TRUE(HostFuncTokenizerCreate.run( + CallFrame, + std::initializer_list{ + TokenizerBinPtr, static_cast(TokenizerBinPath.size()), + TokenizerIdPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + TokenizerId = *MemInst.getPointer(TokenizerIdPtr); + EXPECT_EQ(TokenizerId, 0); + } + + { + + EXPECT_TRUE(HostFuncModelTrain.run( + CallFrame, + std::initializer_list{ + ModelId, TrainDataLoaderId, ValDataLoaderId, TokenizerId, + /*B*/ 4, + /*T*/ 64, + /*Lr*/ 1e-4f, + /*Epoch*/ 20}, + Errno)); + } +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_opencvmini/CMakeLists.txt b/test/plugins/wasmedge_opencvmini/CMakeLists.txt new file mode 100644 index 00000000..9f0946f2 --- /dev/null +++ b/test/plugins/wasmedge_opencvmini/CMakeLists.txt @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasmedgeOpencvminiTests + wasmedge_opencvmini.cpp +) + +add_dependencies(wasmedgeOpencvminiTests + wasmedgePluginWasmEdgeOpenCVMini +) + +target_include_directories(wasmedgeOpencvminiTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeOpencvminiTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeOpencvminiTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeOpencvminiTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeOpencvminiTests wasmedgeOpencvminiTests) diff --git a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp new file mode 100644 index 00000000..c9b07871 --- /dev/null +++ b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/defines.h" +#include "opencvmini_func.h" +#include "opencvmini_module.h" +#include "runtime/instance/module.h" + +#include +#include +#include +#include +#include +#include + +namespace { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_opencvmini/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeOpenCVMini" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_opencvmini"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_opencvmini"sv)) { + return dynamicPointerCast( + Module->create()); + } + } + return {}; +} + +} // namespace + +// TODO: add unit tests for every function. + +TEST(WasmEdgeOpecvminiTest, Module) { + // Create the wasmedge_opencvmini module instance. + auto ImgMod = createModule(); + ASSERT_TRUE(ImgMod); + EXPECT_EQ(ImgMod->getFuncExportNum(), 19U); + EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_imdecode"), nullptr); + EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_imencode"), nullptr); + EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_rectangle"), nullptr); + EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_cvt_color"), nullptr); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_process/CMakeLists.txt b/test/plugins/wasmedge_process/CMakeLists.txt new file mode 100644 index 00000000..fc389115 --- /dev/null +++ b/test/plugins/wasmedge_process/CMakeLists.txt @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasmedgeProcessTests + wasmedge_process.cpp +) + +add_dependencies(wasmedgeProcessTests + wasmedgePluginWasmEdgeProcess +) + +target_include_directories(wasmedgeProcessTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeProcessTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeProcessTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeProcessTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeProcessTests wasmedgeProcessTests) diff --git a/test/plugins/wasmedge_process/wasmedge_process.cpp b/test/plugins/wasmedge_process/wasmedge_process.cpp new file mode 100644 index 00000000..ea80f906 --- /dev/null +++ b/test/plugins/wasmedge_process/wasmedge_process.cpp @@ -0,0 +1,603 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/defines.h" +#include "processfunc.h" +#include "processmodule.h" +#include "runtime/instance/module.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +WasmEdge::Runtime::CallingFrame DummyCallFrame(nullptr, nullptr); + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_process/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeProcess" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_process"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_process"sv)) { + return dynamicPointerCast( + Module->create()); + } + } + return {}; +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, uint32_t Cnt, uint8_t C = 0) noexcept { + std::fill_n(MemInst.getPointer(Offset), Cnt, C); +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, std::string_view Str) noexcept { + char *Buf = MemInst.getPointer(Offset); + std::copy_n(Str.data(), Str.length(), Buf); +} +} // namespace + +using namespace std::literals::string_view_literals; + +TEST(WasmEdgeProcessTest, SetProgName) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Clear the memory[0, 64]. + fillMemContent(MemInst, 0, 64); + // Set the memory[0, 4] as string "echo". + fillMemContent(MemInst, 0, "echo"sv); + + // Get the function "wasmedge_process_set_prog_name". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_set_prog_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = + dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Name, "echo"); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncInst.run( + DummyCallFrame, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); +} + +TEST(WasmEdgeProcessTest, AddArg) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Clear the memory[0, 64]. + fillMemContent(MemInst, 0, 64); + // Set the memory[0, 4] as string "echo". + fillMemContent(MemInst, 0, "arg1"sv); + // Set the memory[4, 8] as string "arg2". + fillMemContent(MemInst, 4, "arg2"sv); + // Set the memory[30, 41] as string "--final-arg". + fillMemContent(MemInst, 30, "--final-arg"sv); + + // Get the function "wasmedge_process_add_arg". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_add_arg"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully to add "arg1". + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Args.size(), 1U); + EXPECT_EQ(ProcMod->getEnv().Args[0], "arg1"); + + // Test: Run function successfully to add "arg2". + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{UINT32_C(4), UINT32_C(4)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Args.size(), 2U); + EXPECT_EQ(ProcMod->getEnv().Args[1], "arg2"); + + // Test: Run function successfully to add "--final-arg". + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{UINT32_C(30), UINT32_C(11)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Args.size(), 3U); + EXPECT_EQ(ProcMod->getEnv().Args[2], "--final-arg"); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncInst.run( + DummyCallFrame, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); +} + +TEST(WasmEdgeProcessTest, AddEnv) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Clear the memory[0, 256]. + fillMemContent(MemInst, 0, 256); + // Set the memory[0, 4] as string "ENV1". + fillMemContent(MemInst, 0, "ENV1"sv); + // Set the memory[4, 10] as string "VALUE1". + fillMemContent(MemInst, 4, "VALUE1"sv); + // Set the memory[30, 45] as string "LD_LIBRARY_PATH". + fillMemContent(MemInst, 30, "LD_LIBRARY_PATH"sv); + // Set the memory[50, 64] as string "/usr/local/lib". + fillMemContent(MemInst, 50, "/usr/local/lib"sv); + + // Get the function "wasmedge_process_add_env". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_add_env"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully to add "ENV1", "VALUE1". + EXPECT_TRUE( + HostFuncInst.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(4), UINT32_C(4), UINT32_C(6)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Envs.size(), 1U); + EXPECT_EQ(ProcMod->getEnv().Envs["ENV1"], "VALUE1"); + + // Test: Run function successfully to add "LD_LIBRARY_PATH", "/usr/local/lib". + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{UINT32_C(30), UINT32_C(15), + UINT32_C(50), UINT32_C(14)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Envs.size(), 2U); + EXPECT_EQ(ProcMod->getEnv().Envs["LD_LIBRARY_PATH"], "/usr/local/lib"); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE( + HostFuncInst.run(DummyCallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(4), UINT32_C(4), UINT32_C(6)}, + {})); +} + +TEST(WasmEdgeProcessTest, AddStdIn) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Clear the memory[0, 64]. + fillMemContent(MemInst, 0, 64); + // Set the memory[0, 4] as string "\01\02\03\04". + fillMemContent(MemInst, 0, "\01\02\03\04"sv); + // Set the memory[30, 46] as string "hello, wasmedge\n". + fillMemContent(MemInst, 30, "hello, wasmedge\n"sv); + + // Get the function "wasmedge_process_add_stdin". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_add_stdin"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully to add "\01\02\03\04". + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); + EXPECT_EQ(ProcMod->getEnv().StdIn.size(), 4U); + EXPECT_EQ(ProcMod->getEnv().StdIn, + std::vector({0x01, 0x02, 0x03, 0x04})); + + // Test: Run function successfully to add "hello, wasmedge\n". + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{UINT32_C(30), UINT32_C(16)}, + {})); + EXPECT_EQ(ProcMod->getEnv().StdIn.size(), 20U); + EXPECT_EQ(ProcMod->getEnv().StdIn, + std::vector({0x01, 0x02, 0x03, 0x04, 'h', 'e', 'l', + 'l', 'o', ',', ' ', 'w', 'a', 's', + 'm', 'e', 'd', 'g', 'e', '\n'})); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncInst.run( + DummyCallFrame, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); +} + +TEST(WasmEdgeProcessTest, SetTimeOut) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Get the function "wasmedge_process_set_timeout". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_set_timeout"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = + dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully to set timeout 100. + EXPECT_TRUE(HostFuncInst.run( + DummyCallFrame, + std::initializer_list{UINT32_C(100)}, {})); + EXPECT_EQ(ProcMod->getEnv().TimeOut, 100U); +} + +TEST(WasmEdgeProcessTest, Run) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Clear the memory[0, 64]. + fillMemContent(MemInst, 0, 64); + // Set the memory[0, 4] as string "\01\02\03\04". + fillMemContent(MemInst, 0, "\01\02\03\04"sv); + // Set the memory[30, 46] as string "hello, wasmedge\n". + fillMemContent(MemInst, 30, "hello, wasmedge\n"sv); + + // Get the function "wasmedge_process_run". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_run"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Return value. + std::array RetVal; + + // Test: Run function fails to run "c++" without allowing all commands. + ProcMod->getEnv().AllowedAll = false; + ProcMod->getEnv().Name = "c++"; + EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), -1); + EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); + EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); + std::string ErrStr = + "Permission denied: Command \"c++\" is not in the white list. Please use " + "--allow-command=c++ or --allow-command-all to add \"c++\" command into " + "the white list.\n"; + EXPECT_TRUE(std::equal(ProcMod->getEnv().StdErr.begin(), + ProcMod->getEnv().StdErr.end(), ErrStr.begin())); + + // Test: Run function successfully to run "c++" while allowing all commands. + ProcMod->getEnv().AllowedAll = true; + ProcMod->getEnv().Name = "c++"; + EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 1); + EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); + EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); + + // Test: Run function successfully to run "c++" while allowing this command. + ProcMod->getEnv().AllowedAll = false; + ProcMod->getEnv().AllowedCmd.insert("c++"); + ProcMod->getEnv().Name = "c++"; + EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 1); + EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); + EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); + + // Test: Run function successfully to run "/bin/echo" while allowing this + // command. + ProcMod->getEnv().AllowedAll = false; + ProcMod->getEnv().AllowedCmd.clear(); + ProcMod->getEnv().AllowedCmd.insert("/bin/echo"); + ProcMod->getEnv().Name = "/bin/echo"; + ProcMod->getEnv().Args.push_back("123456 test"); + EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 0); + EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 12); + EXPECT_TRUE(ProcMod->getEnv().StdErr.size() == 0); + std::string OutStr = "123456 test\n"; + EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), + ProcMod->getEnv().StdOut.end(), OutStr.begin())); +} + +TEST(WasmEdgeProcessTest, TimeoutPrecision) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Get the function "wasmedge_process_run". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_run"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncRun = dynamic_cast( + FuncInst->getHostFunc()); + + // Return value. + std::array RetVal; + + // Run "sleep 2" with a 500ms timeout. + // With the fix, the process should be killed after ~500ms. + // With the old bug (/1000000U instead of /1000U), the microsecond + // component was effectively zero, so timeout only fired at whole-second + // boundaries (i.e., at ~1000ms instead of ~500ms). + ProcMod->getEnv().AllowedAll = true; + ProcMod->getEnv().Name = "sleep"; + ProcMod->getEnv().Args.push_back("2"); + ProcMod->getEnv().TimeOut = 500; + + auto Start = std::chrono::steady_clock::now(); + EXPECT_TRUE(HostFuncRun.run(DummyCallFrame, {}, RetVal)); + auto End = std::chrono::steady_clock::now(); + auto ElapsedMs = + std::chrono::duration_cast(End - Start) + .count(); + + // The process should have been killed due to timeout. + EXPECT_EQ(RetVal[0].get(), static_cast(ETIMEDOUT)); + + // With the fix, elapsed time should be close to 500ms (the timeout value). + // Allow generous margins for CI/scheduling variance, but the key assertion + // is that it finishes well before the 2-second sleep completes and does not + // overshoot to ~1000ms (the old buggy whole-second boundary). + EXPECT_GE(ElapsedMs, 400); + EXPECT_LE(ElapsedMs, 900); +} + +TEST(WasmEdgeProcessTest, GetExitCode) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Get the function "wasmedge_process_get_exit_code". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_get_exit_code"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = + dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully to get exit code. + std::array RetVal; + EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 0); +} + +TEST(WasmEdgeProcessTest, GetStdOut) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Clear the memory[0, 256]. + fillMemContent(MemInst, 0, 256); + + // Get the function "wasmedge_process_run". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_run"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncRun = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "wasmedge_process_run". + FuncInst = ProcMod->findFuncExports("wasmedge_process_get_stdout_len"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetStdOutLen = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "wasmedge_process_run". + FuncInst = ProcMod->findFuncExports("wasmedge_process_get_stdout"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetStdOut = + dynamic_cast( + FuncInst->getHostFunc()); + + // Return value. + std::array RetVal; + + // Run the command "echo $(pwd)". + ProcMod->getEnv().Name = "echo"; + ProcMod->getEnv().AllowedCmd.insert("echo"); + ProcMod->getEnv().Args.push_back("$(pwd)"); + EXPECT_TRUE(HostFuncRun.run(DummyCallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 0U); + + // Test: Run wasmedge_process_get_stdout_len successfully. + EXPECT_TRUE(HostFuncGetStdOutLen.run(DummyCallFrame, {}, RetVal)); + uint32_t Len = RetVal[0].get(); + EXPECT_TRUE(Len > 0U); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncGetStdOut.run( + DummyCallFrame, std::initializer_list{UINT32_C(0)}, + {})); + + // Test: Run wasmedge_process_get_stdout successfully. + EXPECT_TRUE(HostFuncGetStdOut.run( + CallFrame, std::initializer_list{UINT32_C(0)}, {})); + EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), + ProcMod->getEnv().StdOut.end(), + MemInst.getPointer(0))); +} + +TEST(WasmEdgeProcessTest, GetStdErr) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Clear the memory[0, 256]. + fillMemContent(MemInst, 0, 256); + + // Get the function "wasmedge_process_run". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_run"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncRun = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "wasmedge_process_run". + FuncInst = ProcMod->findFuncExports("wasmedge_process_get_stderr_len"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetStdErrLen = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "wasmedge_process_run". + FuncInst = ProcMod->findFuncExports("wasmedge_process_get_stderr"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetStdErr = + dynamic_cast( + FuncInst->getHostFunc()); + + // Return value. + std::array RetVal; + + // Run the command "c++". + ProcMod->getEnv().Name = "c++"; + ProcMod->getEnv().AllowedCmd.insert("c++"); + EXPECT_TRUE(HostFuncRun.run(DummyCallFrame, {}, RetVal)); + EXPECT_NE(RetVal[0].get(), 0U); + + // Test: Run wasmedge_process_get_stdout_len successfully. + EXPECT_TRUE(HostFuncGetStdErrLen.run(DummyCallFrame, {}, RetVal)); + uint32_t Len = RetVal[0].get(); + EXPECT_TRUE(Len > 0U); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncGetStdErr.run( + DummyCallFrame, std::initializer_list{UINT32_C(0)}, + {})); + + // Test: Run wasmedge_process_get_stdout successfully. + EXPECT_TRUE(HostFuncGetStdErr.run( + CallFrame, std::initializer_list{UINT32_C(0)}, {})); + EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), + ProcMod->getEnv().StdOut.end(), + MemInst.getPointer(0))); +} + +TEST(WasmEdgeProcessTest, Module) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + EXPECT_EQ(ProcMod->getEnv().ExitCode, 0U); + EXPECT_EQ(ProcMod->getFuncExportNum(), 11U); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_set_prog_name"), + nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_add_arg"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_add_env"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_add_stdin"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_set_timeout"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_run"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_exit_code"), + nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_stdout_len"), + nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_stdout"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_stderr_len"), + nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_stderr"), nullptr); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt new file mode 100644 index 00000000..dac8abd7 --- /dev/null +++ b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -0,0 +1,47 @@ +wasmedge_add_executable(wasmedgeStableDiffusionTests + wasmedge_stablediffusion.cpp +) + +add_dependencies(wasmedgeStableDiffusionTests + wasmedgePluginWasmEdgeStableDiffusion +) + +target_include_directories(wasmedgeStableDiffusionTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeStableDiffusionTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) + +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeStableDiffusionTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeStableDiffusionTests + PRIVATE + wasmedge_shared + ) +endif() +function(download URL OUTPUT HASH) + file(DOWNLOAD + ${URL} + ${OUTPUT} + SHOW_PROGRESS + EXPECTED_HASH ${HASH} + ) +endfunction() +message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/sd-v1-4.ckpt") +download( + https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt + ${CMAKE_CURRENT_BINARY_DIR}/stableDiffusion/sd-v1-4.ckpt + MD5=c01059060130b8242849d86e97212c84 +) + +add_test(wasmedgeStableDiffusionTests wasmedgeStableDiffusionTests) diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp new file mode 100644 index 00000000..e6817916 --- /dev/null +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -0,0 +1,591 @@ +#include "common/defines.h" +#include "runtime/instance/module.h" +#include "sd_func.h" +#include "sd_module.h" + +#include +#include +#include +#include +#include +#include +#include + +using WasmEdge::Host::StableDiffusion::ErrNo; + +namespace { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_stablediffusion/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeStableDiffusion" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_stablediffusion"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_stablediffusion"sv)) { + return dynamicPointerCast(Module->create()); + } + } + return {}; +} +} // namespace + +template +void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + WasmEdge::Span Binaries, uint32_t Ptr) noexcept { + std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); +} + +void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Value, uint32_t &Ptr) { + uint32_t *BufPtr = MemInst.getPointer(Ptr); + *BufPtr = Value; + Ptr += 4; +} + +void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t PtrVal, uint32_t PtrSize, uint32_t &Ptr) { + writeUInt32(MemInst, PtrVal, Ptr); + writeUInt32(MemInst, PtrSize, Ptr); +} + +// TODO: add unit tests for every function. + +TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { + // Create the stable diffusion module instance. + auto SBMod = createModule(); + ASSERT_TRUE(SBMod); + EXPECT_EQ(SBMod->getFuncExportNum(), 4U); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(2097024))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + uint32_t SessionPtr = UINT32_C(0); + uint32_t SessionId = UINT32_C(0); + uint32_t OutputPtr = UINT32_C(0); + // uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + + // Get the function "convert". + auto *FuncInst = SBMod->findFuncExports("convert"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncConvert = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "create_context". + FuncInst = SBMod->findFuncExports("create_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCreateContext = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "text_to_image". + FuncInst = SBMod->findFuncExports("text_to_image"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncTextToImage = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "image_to_image". + FuncInst = SBMod->findFuncExports("image_to_image"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncImageToImage = + dynamic_cast( + FuncInst->getHostFunc()); + + std::string Prompt = "a lovely cat"; + std::string Prompt2 = "with blue eyes"; + std::vector SkipLayers = {7, 8, 9}; + std::string OutputPathString = "./stableDiffusion/output.png"; + std::vector OutputPath(OutputPathString.begin(), + OutputPathString.end()); + std::string InputPathString = "path:" + OutputPathString; + std::vector InputPath(InputPathString.begin(), InputPathString.end()); + std::string OutputPathString2 = "./stableDiffusion/output2.png"; + std::vector OutputPath2(OutputPathString2.begin(), + OutputPathString2.end()); + std::vector PromptData(Prompt.begin(), Prompt.end()); + std::vector PromptData2(Prompt2.begin(), Prompt2.end()); + std::string ModelPathString = "./stableDiffusion/sd-v1-4.ckpt"; + std::vector ModelPath(ModelPathString.begin(), ModelPathString.end()); + std::string QuantModelPathString = "./stableDiffusion/sd-v1-4-Q8_0.gguf"; + std::vector QuantModelPath(QuantModelPathString.begin(), + QuantModelPathString.end()); + + uint32_t ModelPathPtr = UINT32_C(0); + uint32_t QuantModelPathPtr = ModelPathPtr + ModelPath.size(); + writeBinaries(MemInst, ModelPath, ModelPathPtr); + writeBinaries(MemInst, QuantModelPath, QuantModelPathPtr); + // Test: convert -- convert successfully. + { + EXPECT_TRUE(HostFuncConvert.run( + CallFrame, + std::initializer_list{ + ModelPathPtr, static_cast(ModelPath.size()), 0, 0, + QuantModelPathPtr, static_cast(QuantModelPath.size()), + 8}, // SD_TYPE_Q8_0 = 8 + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(std::filesystem::exists(QuantModelPathString)); + } + + // Test: create_context -- create context for text to image. + { + EXPECT_TRUE(HostFuncCreateContext.run( + CallFrame, + std::initializer_list{ + QuantModelPathPtr, // ModelPathPtr + static_cast(QuantModelPath.size()), // ModelPathLen + 0, // ClipLPathPtr + 0, // ClipLPathLen + 0, // ClipGPathPtr + 0, // ClipGPathLen + 0, // T5xxlPathPtr + 0, // T5xxlPathLen + 0, // DiffusionModelPathPtr + 0, // DiffusionModelPathLen + 0, // VaePathPtr + 0, // VaePathLen + 0, // TaesdPathPtr + 0, // TaesdPathLen + 0, // ControlNetPathPtr + 0, // ControlNetPathLen + 0, // LoraModelDirPtr + 0, // LoraModelDirLen + 0, // EmbedDirPtr + 0, // EmbedDirLen + 0, // IdEmbedDirPtr + 0, // IdEmbedDirLen + 1, // VaeDecodeOnly + 0, // VaeTiling + -1, // NThreads + 36, // Wtype + 1, // RngType + 0, // Schedule + 0, // ClipOnCpu + 0, // ControlNetCpu + 0, // VaeOnCpu + 0, // DiffusionFlashAttn + SessionPtr}, // SessiontIdPtr + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + SessionId = *MemInst.getPointer(SessionPtr); + EXPECT_EQ(SessionId, 0); + } + + // Test: text_to_image -- generate image from text. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t SkipLayersPtr = PromptPtr + PromptData.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData, PromptPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); + writeBinaries(MemInst, OutputPath, OutputPathPtr); + EXPECT_TRUE(HostFuncTextToImage.run( + CallFrame, + std::initializer_list{ + PromptPtr, // PromptPtr + static_cast(PromptData.size()), // PromptLen + SessionId, // SessionId + 0, // ControlImagePtr + 0, // ControlImageLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + 3.5f, // Guidance + 256, // Width + 256, // Height + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 1, // SampleSteps + 42, // Seed + 1, // BatchCount + 0.90f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 1048512, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); + EXPECT_GE(BytesWritten, 50); + EXPECT_TRUE(std::filesystem::exists(OutputPathString)); + } + + // Test: text_to_image -- reuse context to generate image from text. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t SkipLayersPtr = PromptPtr + PromptData.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData, PromptPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); + writeBinaries(MemInst, OutputPath, OutputPathPtr); + EXPECT_TRUE(HostFuncTextToImage.run( + CallFrame, + std::initializer_list{ + PromptPtr, // PromptPtr + static_cast(PromptData.size()), // PromptLen + SessionId, // SessionId + 0, // ControlImagePtr + 0, // ControlImageLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + 3.5f, // Guidance + 256, // Width + 256, // Height + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 20, // SampleSteps + 42, // Seed + 1, // BatchCount + 0.90f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 1048512, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); + EXPECT_GE(BytesWritten, 50); + EXPECT_TRUE(std::filesystem::exists(OutputPathString)); + } + writeBinaries(MemInst, ModelPath, ModelPathPtr); + writeBinaries(MemInst, QuantModelPath, QuantModelPathPtr); + // Test: create_context -- create context for image to image. + { + EXPECT_TRUE(HostFuncCreateContext.run( + CallFrame, + std::initializer_list{ + QuantModelPathPtr, // ModelPathPtr + static_cast(QuantModelPath.size()), // ModelPathLen + 0, // ClipLPathPtr + 0, // ClipLPathLen + 0, // ClipGPathPtr + 0, // ClipGPathLen + 0, // T5xxlPathPtr + 0, // T5xxlPathLen + 0, // DiffusionModelPathPtr + 0, // DiffusionModelPathLen + 0, // VaePathPtr + 0, // VaePathLen + 0, // TaesdPathPtr + 0, // TaesdPathLen + 0, // ControlNetPathPtr + 0, // ControlNetPathLen + 0, // LoraModelDirPtr + 0, // LoraModelDirLen + 0, // EmbedDirPtr + 0, // EmbedDirLen + 0, // IdEmbedDirPtr + 0, // IdEmbedDirLen + 0, // VaeDecodeOnly + 0, // VaeTiling + -1, // NThreads + 36, // Wtype + 1, // RngType + 0, // Schedule + 0, // ClipOnCpu + 0, // ControlNetCpu + 0, // VaeOnCpu + 0, // DiffusionFlashAttn + SessionPtr}, // SessiontIdPtr + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + SessionId = *MemInst.getPointer(SessionPtr); + EXPECT_EQ(SessionId, 1); + } + // Test: image_to_image -- generate image from image. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t InputPathPtr = PromptPtr + PromptData2.size(); + uint32_t SkipLayersPtr = InputPathPtr + InputPath.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath2.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData2, PromptPtr); + writeBinaries(MemInst, InputPath, InputPathPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); + writeBinaries(MemInst, OutputPath2, OutputPathPtr); + EXPECT_TRUE(HostFuncImageToImage.run( + CallFrame, + std::initializer_list{ + InputPathPtr, // ImagePtr + static_cast(InputPath.size()), // ImageLen + 0, // MaskInputPtr + 0, // MaskInputLen + SessionId, // SessionId + 3.5f, // Guidance + 256, // Width + 256, // Height + 0, // ControlImagePtr + 0, // ControlImageLen + PromptPtr, // PromptPtr + static_cast(PromptData2.size()), // PromptLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 1, // SampleSteps + 0.75f, // Strength + 42, // Seed + 1, // BatchCount + 0.9f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath2.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 1048512, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); + EXPECT_GE(BytesWritten, 50); + EXPECT_TRUE(std::filesystem::exists(OutputPathString2)); + } + // Test: image_to_image -- reuse context to generate image from image. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t InputPathPtr = PromptPtr + PromptData2.size(); + uint32_t SkipLayersPtr = InputPathPtr + InputPath.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath2.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData2, PromptPtr); + writeBinaries(MemInst, InputPath, InputPathPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); + writeBinaries(MemInst, OutputPath2, OutputPathPtr); + EXPECT_TRUE(HostFuncImageToImage.run( + CallFrame, + std::initializer_list{ + InputPathPtr, // ImagePtr + static_cast(InputPath.size()), // ImageLen + 0, // MaskInputPtr + 0, // MaskInputLen + SessionId, // SessionId + 3.5f, // Guidance + 256, // Width + 256, // Height + 0, // ControlImagePtr + 0, // ControlImageLen + PromptPtr, // PromptPtr + static_cast(PromptData2.size()), // PromptLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 20, // SampleSteps + 0.75f, // Strength + 42, // Seed + 1, // BatchCount + 0.9f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath2.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 1048512, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); + EXPECT_GE(BytesWritten, 50); + EXPECT_TRUE(std::filesystem::exists(OutputPathString2)); + } + + // Test: text_to_image -- non exist SessionId. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t SkipLayersPtr = PromptPtr + PromptData.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData, PromptPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); + writeBinaries(MemInst, OutputPath, OutputPathPtr); + EXPECT_TRUE(HostFuncTextToImage.run( + CallFrame, + std::initializer_list{ + PromptPtr, // PromptPtr + static_cast(PromptData.size()), // PromptLen + 99, // SessionId + 0, // ControlImagePtr + 0, // ControlImageLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + 3.5f, // Guidance + 256, // Width + 256, // Height + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 1, // SampleSteps + 42, // Seed + 1, // BatchCount + 0.90f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 1048512, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: image_to_image -- non exist SessionId. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t InputPathPtr = PromptPtr + PromptData2.size(); + uint32_t SkipLayersPtr = InputPathPtr + InputPath.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath2.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData2, PromptPtr); + writeBinaries(MemInst, InputPath, InputPathPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); + writeBinaries(MemInst, OutputPath2, OutputPathPtr); + EXPECT_TRUE(HostFuncImageToImage.run( + CallFrame, + std::initializer_list{ + InputPathPtr, // ImagePtr + static_cast(InputPath.size()), // ImageLen + 0, // MaskInputPtr + 0, // MaskInputLen + -1, // SessionId + 3.5f, // Guidance + 256, // Width + 256, // Height + 0, // ControlImagePtr + 0, // ControlImageLen + PromptPtr, // PromptPtr + static_cast(PromptData2.size()), // PromptLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 20, // SampleSteps + 0.75f, // Strength + 42, // Seed + 1, // BatchCount + 0.9f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath2.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 1048512, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_tensorflow/CMakeLists.txt b/test/plugins/wasmedge_tensorflow/CMakeLists.txt new file mode 100644 index 00000000..2104f3b3 --- /dev/null +++ b/test/plugins/wasmedge_tensorflow/CMakeLists.txt @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasmedgeTensorflowTests + wasmedge_tensorflow.cpp +) + +# On manylinux_2_28 with gcc-toolset and CXX11_ABI=0, LTO can cause unresolved +# internal libstdc++ symbols from libstdc++_nonshared.a (cow-fs_path.o) when +# linking against the TensorFlow C++ shared libraries. +if(CMAKE_COMPILER_IS_GNUCXX AND NOT WASMEDGE_USE_CXX11_ABI) + set_target_properties(wasmedgeTensorflowTests PROPERTIES + INTERPROCEDURAL_OPTIMIZATION OFF + ) +endif() + +add_dependencies(wasmedgeTensorflowTests + wasmedgePluginWasmEdgeTensorflow +) + +include(WASINNDeps) +wasmedge_setup_tf_target(wasmedgeTensorflowTests) + +target_include_directories(wasmedgeTensorflowTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeTensorflowTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeTensorflowTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeTensorflowTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeTensorflowTests wasmedgeTensorflowTests) diff --git a/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp new file mode 100644 index 00000000..7d1cc84a --- /dev/null +++ b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/defines.h" +#include "runtime/instance/module.h" +#include "tensorflow_func.h" +#include "tensorflow_module.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_tensorflow/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeTensorflow" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_tensorflow"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_tensorflow"sv)) { + return dynamicPointerCast( + Module->create()); + } + } + return {}; +} +} // namespace + +// TODO: add unit tests for every function. + +TEST(WasmEdgeTensorflowTest, Module) { + // Create the wasmedge_tensorflow module instance. + auto TFMod = createModule(); + ASSERT_TRUE(TFMod); + + EXPECT_EQ(TFMod->getFuncExportNum(), 11U); + EXPECT_NE(TFMod->findFuncExports("create_session"), nullptr); + EXPECT_NE(TFMod->findFuncExports("create_session_saved_model"), nullptr); + EXPECT_NE(TFMod->findFuncExports("delete_session"), nullptr); + EXPECT_NE(TFMod->findFuncExports("run_session"), nullptr); + EXPECT_NE(TFMod->findFuncExports("get_output_tensor"), nullptr); + EXPECT_NE(TFMod->findFuncExports("get_tensor_len"), nullptr); + EXPECT_NE(TFMod->findFuncExports("get_tensor_data"), nullptr); + EXPECT_NE(TFMod->findFuncExports("append_input"), nullptr); + EXPECT_NE(TFMod->findFuncExports("append_output"), nullptr); + EXPECT_NE(TFMod->findFuncExports("clear_input"), nullptr); + EXPECT_NE(TFMod->findFuncExports("clear_output"), nullptr); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_tensorflowlite/CMakeLists.txt b/test/plugins/wasmedge_tensorflowlite/CMakeLists.txt new file mode 100644 index 00000000..bbf2ff60 --- /dev/null +++ b/test/plugins/wasmedge_tensorflowlite/CMakeLists.txt @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasmedgeTensorflowLiteTests + wasmedge_tensorflowlite.cpp +) + +add_dependencies(wasmedgeTensorflowLiteTests + wasmedgePluginWasmEdgeTensorflowLite +) + +include(WASINNDeps) +wasmedge_setup_tflite_target(wasmedgeTensorflowLiteTests) + +target_include_directories(wasmedgeTensorflowLiteTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeTensorflowLiteTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeTensorflowLiteTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeTensorflowLiteTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeTensorflowLiteTests wasmedgeTensorflowLiteTests) diff --git a/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp new file mode 100644 index 00000000..d8f3fc94 --- /dev/null +++ b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/defines.h" +#include "runtime/instance/module.h" +#include "tensorflowlite_func.h" +#include "tensorflowlite_module.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_tensorflowlite/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeTensorflowLite" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_tensorflowlite"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_tensorflowlite"sv)) { + return dynamicPointerCast( + Module->create()); + } + } + return {}; +} + +} // namespace + +// TODO: add unit tests for every function. + +TEST(WasmEdgeTensorflowLiteTest, Module) { + // Create the wasmedge_tensorflowlite module instance. + auto TFLiteMod = createModule(); + ASSERT_TRUE(TFLiteMod); + EXPECT_EQ(TFLiteMod->getFuncExportNum(), 7U); + EXPECT_NE(TFLiteMod->findFuncExports("create_session"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("delete_session"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("run_session"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("get_output_tensor"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("get_tensor_len"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("get_tensor_data"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("append_input"), nullptr); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_zlib/CMakeLists.txt b/test/plugins/wasmedge_zlib/CMakeLists.txt new file mode 100644 index 00000000..5b9ad3ec --- /dev/null +++ b/test/plugins/wasmedge_zlib/CMakeLists.txt @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasmedgeZlibTests + wasmedge_zlib.cpp +) + +add_dependencies(wasmedgeZlibTests + wasmedgePluginWasmEdgeZlib +) + +target_include_directories(wasmedgeZlibTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeZlibTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeZlibTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeZlibTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeZlibTests wasmedgeZlibTests) diff --git a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp new file mode 100644 index 00000000..cb6ce197 --- /dev/null +++ b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "common/defines.h" +#include "runtime/instance/module.h" +#include "zlibfunc.h" +#include "zlibmodule.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +WasmEdge::Runtime::CallingFrame DummyCallFrame(nullptr, nullptr); + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_zlib/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeZlib" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_zlib"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_zlib"sv)) { + return dynamicPointerCast( + Module->create()); + } + } + return {}; +} + +} // namespace + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, uint32_t Cnt, uint8_t C = 0) noexcept { + std::fill_n(MemInst.getPointer(Offset), Cnt, C); +} + +static constexpr size_t DataSize = 1 * 1024 * 1024ULL; +static constexpr size_t OutputBufferSize = 64 * 1024ULL; + +constexpr auto RandChar = []() -> char { + constexpr char Charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + constexpr size_t MaxIndex = (sizeof(Charset) - 1); + return Charset[std::rand() % MaxIndex]; +}; + +TEST(WasmEdgeZlibTest, DeflateInflateCycle) { + auto ZlibMod = createModule(); + ASSERT_TRUE(ZlibMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(16 * 64, 16 * 64))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + uint32_t + // WASM Memory Heap Pointer + WasmHP = 1, + WasmData, WasmZlibVersion, ModuleZStream, WasmCompressedData, + WasmDecompressedData; + uint32_t WasmCompressedData_size = 0, WasmDecompressedDataSize = 0; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + auto *FuncInst = ZlibMod->findFuncExports("deflateInit_"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &DeflateInit_ = FuncInst->getHostFunc(); + + FuncInst = ZlibMod->findFuncExports("deflate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &Deflate = FuncInst->getHostFunc(); + + FuncInst = ZlibMod->findFuncExports("deflateEnd"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &DeflateEnd = FuncInst->getHostFunc(); + + FuncInst = ZlibMod->findFuncExports("inflateInit_"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &InflateInit_ = FuncInst->getHostFunc(); + + FuncInst = ZlibMod->findFuncExports("inflate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &Inflate = FuncInst->getHostFunc(); + + FuncInst = ZlibMod->findFuncExports("inflateEnd"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &InflateEnd = FuncInst->getHostFunc(); + + std::array RetVal; + + WasmZlibVersion = WasmHP; + std::snprintf(MemInst.getPointer(WasmHP), + std::strlen(ZLIB_VERSION) + 1, ZLIB_VERSION); + WasmHP += std::strlen(ZLIB_VERSION); + + WasmData = WasmHP; + std::generate_n(MemInst.getPointer(WasmHP), DataSize, RandChar); + WasmHP += DataSize; + + ModuleZStream = WasmHP; + WasmZStream *strm = MemInst.getPointer(ModuleZStream); + WasmHP += sizeof(WasmZStream); + + // ----- Deflate Routine START------ + fillMemContent(MemInst, ModuleZStream, sizeof(WasmZStream), 0U); + + // deflateInit_ Test + // WASM z_stream size Mismatch + EXPECT_TRUE(DeflateInit_.run(CallFrame, + std::initializer_list{ + ModuleZStream, INT32_C(-1), WasmZlibVersion, + sizeof(WasmZStream) + 16}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_VERSION_ERROR); + + // Version Mismatch + EXPECT_TRUE(DeflateInit_.run( + CallFrame, + std::initializer_list{ + ModuleZStream, INT32_C(-1), WasmZlibVersion + 2, sizeof(WasmZStream)}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_VERSION_ERROR); + + EXPECT_TRUE(DeflateInit_.run( + CallFrame, + std::initializer_list{ + ModuleZStream, INT32_C(-1), WasmZlibVersion, sizeof(WasmZStream)}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_OK); + + WasmCompressedData = WasmHP; + + strm->AvailIn = DataSize; + strm->NextIn = WasmData; + strm->AvailOut = OutputBufferSize; + strm->NextOut = WasmCompressedData; + + // deflate Test + do { + if (strm->AvailOut == 0) { + WasmHP += OutputBufferSize; + strm->AvailOut = OutputBufferSize; + strm->NextOut = WasmHP; + } + + EXPECT_TRUE(Deflate.run(CallFrame, + std::initializer_list{ + ModuleZStream, + INT32_C(Z_FINISH), + }, + RetVal)); + EXPECT_NE(RetVal[0].get(), Z_STREAM_ERROR); + } while (RetVal[0].get() != Z_STREAM_END); + + // deflateEnd Test + EXPECT_TRUE(DeflateEnd.run( + CallFrame, std::initializer_list{ModuleZStream}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_OK); + WasmHP += OutputBufferSize - strm->AvailOut; + WasmCompressedData_size = WasmHP - WasmCompressedData; + // ----- Deflate Routine END------ + + // ----- Inflate Routine START------ + fillMemContent(MemInst, ModuleZStream, sizeof(WasmZStream), 0U); + + // inflateInit_ Test + // WASM z_stream size Mismatch + EXPECT_TRUE(InflateInit_.run( + CallFrame, + std::initializer_list{ + ModuleZStream, WasmZlibVersion, sizeof(WasmZStream) + 16}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_VERSION_ERROR); + + // Version Mismatch + EXPECT_TRUE(InflateInit_.run( + CallFrame, + std::initializer_list{ + ModuleZStream, WasmZlibVersion + 2, sizeof(WasmZStream)}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_VERSION_ERROR); + + EXPECT_TRUE( + InflateInit_.run(CallFrame, + std::initializer_list{ + ModuleZStream, WasmZlibVersion, sizeof(WasmZStream)}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_OK); + + WasmDecompressedData = WasmHP; + + strm->AvailIn = WasmCompressedData_size; + strm->NextIn = WasmCompressedData; + strm->AvailOut = OutputBufferSize; + strm->NextOut = WasmDecompressedData; + + // inflate test + do { + if (strm->AvailOut == 0) { + WasmHP += OutputBufferSize; + strm->AvailOut = OutputBufferSize; + strm->NextOut = WasmHP; + } + + EXPECT_TRUE(Inflate.run(CallFrame, + std::initializer_list{ + ModuleZStream, + INT32_C(Z_FINISH), + }, + RetVal)); + EXPECT_NE(RetVal[0].get(), Z_STREAM_ERROR); + } while (RetVal[0].get() != Z_STREAM_END); + + EXPECT_TRUE(InflateEnd.run( + CallFrame, std::initializer_list{ModuleZStream}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_OK); + WasmHP += OutputBufferSize - strm->AvailOut; + WasmDecompressedDataSize = WasmHP - WasmDecompressedData; + // ----- Inflate Routine END------ + + // Test Decompressed Buffer size against source Data size. + EXPECT_EQ(WasmDecompressedDataSize, DataSize); + // Test Decompressed Buffer content against source Data. + EXPECT_TRUE(std::equal(MemInst.getPointer(WasmDecompressedData), + MemInst.getPointer( + WasmDecompressedData + WasmDecompressedDataSize), + MemInst.getPointer(WasmData))); +} + +TEST(WasmEdgeZlibTest, Module) { + // Create the wasmedge_zlib module instance. + auto ZlibMod = createModule(); + ASSERT_TRUE(ZlibMod); + + EXPECT_TRUE(ZlibMod->getEnv().ZStreamMap.empty()); + EXPECT_EQ(ZlibMod->getFuncExportNum(), 76U); + + EXPECT_NE(ZlibMod->findFuncExports("deflateInit"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflate"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateEnd"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateInit"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflate"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateEnd"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateInit2"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateSetDictionary"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateGetDictionary"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateCopy"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateReset"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateParams"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateTune"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateBound"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflatePending"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflatePrime"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateSetHeader"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateInit2"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateSetDictionary"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateGetDictionary"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateSync"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateCopy"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateReset"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateReset2"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflatePrime"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateMark"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateGetHeader"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateBackInit"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateBackEnd"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("zlibCompileFlags"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("compress"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("compress2"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("compressBound"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("uncompress"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("uncompress2"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzopen"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzdopen"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzbuffer"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzsetparams"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzread"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzfread"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzwrite"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzfwrite"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzputs"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzputc"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzgetc"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzungetc"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzflush"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzseek"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzrewind"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gztell"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzoffset"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzeof"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzdirect"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzclose"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzclose_r"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzclose_w"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzclearerr"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("adler32"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("adler32_z"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("adler32_combine"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("crc32"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("crc32_z"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("crc32_combine"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateInit_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateInit_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateInit2_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateInit2_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateBackInit_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzgetc_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateSyncPoint"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateUndermine"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateValidate"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateCodesUsed"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateResetKeep"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateResetKeep"), nullptr); +} + +GTEST_API_ int main(int ArgC, char **ArgV) { + testing::InitGoogleTest(&ArgC, ArgV); + return RUN_ALL_TESTS(); +} diff --git a/thirdparty/wasi_crypto/api.hpp b/thirdparty/wasi_crypto/api.hpp new file mode 100644 index 00000000..c5eb6ddb --- /dev/null +++ b/thirdparty/wasi_crypto/api.hpp @@ -0,0 +1,675 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +/** + * THIS FILE IS AUTO-GENERATED from the following files: + * proposal_kx.witx, proposal_asymmetric_common.witx, proposal_common.witx, proposal_signatures.witx, proposal_symmetric.witx, proposal_external_secrets.witx + * + * @file + * This file describes the [WASI] interface, consisting of functions, types, + * and defined values (macros). + * + * The interface described here is greatly inspired by [CloudABI]'s clean, + * thoughtfully-designed, capability-oriented, POSIX-style API. + * + * [CloudABI]: https://github.com/NuxiNL/cloudlibc + * [WASI]: https://github.com/WebAssembly/WASI/ + */ + +#pragma once + +#include +#include +#include + +using const_uint8_t_ptr = uint32_t; +using uint8_t_ptr = uint32_t; + +#define DEFINE_ENUM_OPERATORS(type) \ + inline constexpr type operator~(type a) noexcept { \ + return static_cast(~static_cast>(a)); \ + } \ + inline constexpr type operator|(type a, type b) noexcept { \ + return static_cast(static_cast>(a) | \ + static_cast>(b)); \ + } \ + inline constexpr type &operator|=(type &a, type b) noexcept { \ + a = a | b; \ + return a; \ + } \ + inline constexpr type operator&(type a, type b) noexcept { \ + return static_cast(static_cast>(a) & \ + static_cast>(b)); \ + } \ + inline constexpr type &operator&=(type &a, type b) noexcept { \ + a = a & b; \ + return a; \ + } + +static_assert(alignof(int8_t) == 1, "non-wasi data layout"); +static_assert(alignof(uint8_t) == 1, "non-wasi data layout"); +static_assert(alignof(int16_t) == 2, "non-wasi data layout"); +static_assert(alignof(uint16_t) == 2, "non-wasi data layout"); +static_assert(alignof(int32_t) == 4, "non-wasi data layout"); +static_assert(alignof(uint32_t) == 4, "non-wasi data layout"); +static_assert(alignof(int64_t) == 8, "non-wasi data layout"); +static_assert(alignof(uint64_t) == 8, "non-wasi data layout"); +static_assert(alignof(const_uint8_t_ptr) == 4, "non-wasi data layout"); +static_assert(alignof(uint8_t_ptr) == 4, "non-wasi data layout"); + +/** + * Error codes. + */ +enum __wasi_crypto_errno_e_t : uint16_t { + /** + * Operation succeeded. + */ + __WASI_CRYPTO_ERRNO_SUCCESS = 0, + + /** + * An error occurred when trying to during a conversion from a host type to a guest type. + * + * Only an internal bug can throw this error. + */ + __WASI_CRYPTO_ERRNO_GUEST_ERROR = 1, + + /** + * The requested operation is valid, but not implemented by the host. + */ + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED = 2, + + /** + * The requested feature is not supported by the chosen algorithm. + */ + __WASI_CRYPTO_ERRNO_UNSUPPORTED_FEATURE = 3, + + /** + * The requested operation is valid, but was administratively prohibited. + */ + __WASI_CRYPTO_ERRNO_PROHIBITED_OPERATION = 4, + + /** + * Unsupported encoding for an import or export operation. + */ + __WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING = 5, + + /** + * The requested algorithm is not supported by the host. + */ + __WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM = 6, + + /** + * The requested option is not supported by the currently selected algorithm. + */ + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION = 7, + + /** + * An invalid or incompatible key was supplied. + * + * The key may not be valid, or was generated for a different algorithm or parameters set. + */ + __WASI_CRYPTO_ERRNO_INVALID_KEY = 8, + + /** + * The currently selected algorithm doesn't support the requested output length. + * + * This error is thrown by non-extensible hash functions, when requesting an output size larger than they produce out of a single block. + */ + __WASI_CRYPTO_ERRNO_INVALID_LENGTH = 9, + + /** + * A signature or authentication tag verification failed. + */ + __WASI_CRYPTO_ERRNO_VERIFICATION_FAILED = 10, + + /** + * A secure random numbers generator is not available. + * + * The requested operation requires random numbers, but the host cannot securely generate them at the moment. + */ + __WASI_CRYPTO_ERRNO_RNG_ERROR = 11, + + /** + * An error was returned by the underlying cryptography library. + * + * The host may be running out of memory, parameters may be incompatible with the chosen implementation of an algorithm or another unexpected error may have happened. + * + * Ideally, the specification should provide enough details and guidance to make this error impossible to ever be thrown. + * + * Realistically, the WASI crypto module cannot possibly cover all possible error types implementations can return, especially since some of these may be language-specific. + * This error can thus be thrown when other error types are not suitable, and when the original error comes from the cryptographic primitives themselves and not from the WASI module. + */ + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE = 12, + + /** + * The supplied signature is invalid, or incompatible with the chosen algorithm. + */ + __WASI_CRYPTO_ERRNO_INVALID_SIGNATURE = 13, + + /** + * An attempt was made to close a handle that was already closed. + */ + __WASI_CRYPTO_ERRNO_CLOSED = 14, + + /** + * A function was called with an unassigned handle, a closed handle, or handle of an unexpected type. + */ + __WASI_CRYPTO_ERRNO_INVALID_HANDLE = 15, + + /** + * The host needs to copy data to a guest-allocated buffer, but that buffer is too small. + */ + __WASI_CRYPTO_ERRNO_OVERFLOW = 16, + + /** + * An internal error occurred. + * + * This error is reserved to internal consistency checks, and must only be sent if the internal state of the host remains safe after an inconsistency was detected. + */ + __WASI_CRYPTO_ERRNO_INTERNAL_ERROR = 17, + + /** + * Too many handles are currently open, and a new one cannot be created. + * + * Implementations are free to represent handles as they want, and to enforce limits to limit resources usage. + */ + __WASI_CRYPTO_ERRNO_TOO_MANY_HANDLES = 18, + + /** + * A key was provided, but the chosen algorithm doesn't support keys. + * + * This is returned by symmetric operations. + * + * Many hash functions, in particular, do not support keys without being used in particular constructions. + * Blindly ignoring a key provided by mistake while trying to open a context for such as function could cause serious security vulnerabilities. + * + * These functions must refuse to create the context and return this error instead. + */ + __WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED = 19, + + /** + * A key is required for the chosen algorithm, but none was given. + */ + __WASI_CRYPTO_ERRNO_KEY_REQUIRED = 20, + + /** + * The provided authentication tag is invalid or incompatible with the current algorithm. + * + * This error is returned by decryption functions and tag verification functions. + * + * Unlike `verification_failed`, this error code is returned when the tag cannot possibly verify for any input. + */ + __WASI_CRYPTO_ERRNO_INVALID_TAG = 21, + + /** + * The requested operation is incompatible with the current scheme. + * + * For example, the `symmetric_state_encrypt()` function cannot complete if the selected construction is a key derivation function. + * This error code will be returned instead. + */ + __WASI_CRYPTO_ERRNO_INVALID_OPERATION = 22, + + /** + * A nonce is required. + * + * Most encryption schemes require a nonce. + * + * In the absence of a nonce, the WASI cryptography module can automatically generate one, if that can be done safely. The nonce can be retrieved later with the `symmetric_state_option_get()` function using the `nonce` parameter. + * If automatically generating a nonce cannot be done safely, the module never falls back to an insecure option and requests an explicit nonce by throwing that error. + */ + __WASI_CRYPTO_ERRNO_NONCE_REQUIRED = 23, + + /** + * The provided nonce doesn't have a correct size for the given cipher. + */ + __WASI_CRYPTO_ERRNO_INVALID_NONCE = 24, + + /** + * The named option was not set. + * + * The caller tried to read the value of an option that was not set. + * This error is used to make the distinction between an empty option, and an option that was not set and left to its default value. + */ + __WASI_CRYPTO_ERRNO_OPTION_NOT_SET = 25, + + /** + * A key or key pair matching the requested identifier cannot be found using the supplied information. + * + * This error is returned by a secrets manager via the `keypair_from_id()` function. + */ + __WASI_CRYPTO_ERRNO_NOT_FOUND = 26, + + /** + * The algorithm requires parameters that haven't been set. + * + * Non-generic options are required and must be given by building an `options` set and giving that object to functions instantiating that algorithm. + */ + __WASI_CRYPTO_ERRNO_PARAMETERS_MISSING = 27, + + /** + * A requested computation is not done yet, and additional calls to the function are required. + * + * Some functions, such as functions generating key pairs and password stretching functions, can take a long time to complete. + * + * In order to avoid a host call to be blocked for too long, these functions can return prematurely, requiring additional calls with the same parameters until they complete. + */ + __WASI_CRYPTO_ERRNO_IN_PROGRESS = 28, + + /** + * Multiple keys have been provided, but they do not share the same type. + * + * This error is returned when trying to build a key pair from a public key and a secret key that were created for different and incompatible algorithms. + */ + __WASI_CRYPTO_ERRNO_INCOMPATIBLE_KEYS = 29, + + /** + * A managed key or secret expired and cannot be used any more. + */ + __WASI_CRYPTO_ERRNO_EXPIRED = 30, + +}; +static_assert(sizeof(__wasi_crypto_errno_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_crypto_errno_e_t) == 2, "witx calculated align"); + +/** + * Encoding to use for importing or exporting a key pair. + */ +enum __wasi_keypair_encoding_e_t : uint16_t { + /** + * Raw bytes. + */ + __WASI_KEYPAIR_ENCODING_RAW = 0, + + /** + * PCSK8/DER encoding. + */ + __WASI_KEYPAIR_ENCODING_PKCS8 = 1, + + /** + * PEM encoding. + */ + __WASI_KEYPAIR_ENCODING_PEM = 2, + + /** + * Implementation-defined encoding. + */ + __WASI_KEYPAIR_ENCODING_LOCAL = 3, + +}; +static_assert(sizeof(__wasi_keypair_encoding_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_keypair_encoding_e_t) == 2, "witx calculated align"); + +/** + * Encoding to use for importing or exporting a public key. + */ +enum __wasi_publickey_encoding_e_t : uint16_t { + /** + * Raw bytes. + */ + __WASI_PUBLICKEY_ENCODING_RAW = 0, + + /** + * PKCS8/DER encoding. + */ + __WASI_PUBLICKEY_ENCODING_PKCS8 = 1, + + /** + * PEM encoding. + */ + __WASI_PUBLICKEY_ENCODING_PEM = 2, + + /** + * SEC-1 encoding. + */ + __WASI_PUBLICKEY_ENCODING_SEC = 3, + + /** + * Implementation-defined encoding. + */ + __WASI_PUBLICKEY_ENCODING_LOCAL = 4, + +}; +static_assert(sizeof(__wasi_publickey_encoding_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_publickey_encoding_e_t) == 2, "witx calculated align"); + +/** + * Encoding to use for importing or exporting a secret key. + */ +enum __wasi_secretkey_encoding_e_t : uint16_t { + /** + * Raw bytes. + */ + __WASI_SECRETKEY_ENCODING_RAW = 0, + + /** + * PKCS8/DER encoding. + */ + __WASI_SECRETKEY_ENCODING_PKCS8 = 1, + + /** + * PEM encoding. + */ + __WASI_SECRETKEY_ENCODING_PEM = 2, + + /** + * SEC-1 encoding. + */ + __WASI_SECRETKEY_ENCODING_SEC = 3, + + /** + * Implementation-defined encoding. + */ + __WASI_SECRETKEY_ENCODING_LOCAL = 4, + +}; +static_assert(sizeof(__wasi_secretkey_encoding_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_secretkey_encoding_e_t) == 2, "witx calculated align"); + +/** + * Encoding to use for importing or exporting a signature. + */ +enum __wasi_signature_encoding_e_t : uint16_t { + /** + * Raw bytes. + */ + __WASI_SIGNATURE_ENCODING_RAW = 0, + + /** + * DER encoding. + */ + __WASI_SIGNATURE_ENCODING_DER = 1, + +}; +static_assert(sizeof(__wasi_signature_encoding_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_signature_encoding_e_t) == 2, "witx calculated align"); + +/** + * An algorithm category. + */ +enum __wasi_algorithm_type_e_t : uint16_t { + __WASI_ALGORITHM_TYPE_SIGNATURES = 0, + + __WASI_ALGORITHM_TYPE_SYMMETRIC = 1, + + __WASI_ALGORITHM_TYPE_KEY_EXCHANGE = 2, + +}; +static_assert(sizeof(__wasi_algorithm_type_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_algorithm_type_e_t) == 2, "witx calculated align"); + +/** + * Version of a managed key. + * + * A version can be an arbitrary `u64` integer, with the exception of some reserved values. + */ +using __wasi_version_t = uint64_t; + +static_assert(sizeof(__wasi_version_t) == 8, "witx calculated size"); +static_assert(alignof(__wasi_version_t) == 8, "witx calculated align"); + +/** + * Size of a value. + */ +using __wasi_size_t = uint32_t; + +static_assert(sizeof(__wasi_size_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_size_t) == 4, "witx calculated align"); + +/** + * A UNIX timestamp, in seconds since 01/01/1970. + */ +using __wasi_timestamp_t = uint64_t; + +static_assert(sizeof(__wasi_timestamp_t) == 8, "witx calculated size"); +static_assert(alignof(__wasi_timestamp_t) == 8, "witx calculated align"); + +/** + * Handle for functions returning output whose size may be large or not known in advance. + * + * An `array_output` object contains a host-allocated byte array. + * + * A guest can get the size of that array after a function returns in order to then allocate a buffer of the correct size. + * In addition, the content of such an object can be consumed by a guest in a streaming fashion. + * + * An `array_output` handle is automatically closed after its full content has been consumed. + */ +using __wasi_array_output_t = int32_t; + +static_assert(sizeof(__wasi_array_output_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_array_output_t) == 4, "witx calculated align"); + +/** + * A set of options. + * + * This type is used to set non-default parameters. + * + * The exact set of allowed options depends on the algorithm being used. + */ +using __wasi_options_t = int32_t; + +static_assert(sizeof(__wasi_options_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_options_t) == 4, "witx calculated align"); + +/** + * A handle to the optional secrets management facilities offered by a host. + * + * This is used to generate, retrieve and invalidate managed keys. + */ +using __wasi_secrets_manager_t = int32_t; + +static_assert(sizeof(__wasi_secrets_manager_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_secrets_manager_t) == 4, "witx calculated align"); + +/** + * A key pair. + */ +using __wasi_keypair_t = int32_t; + +static_assert(sizeof(__wasi_keypair_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_keypair_t) == 4, "witx calculated align"); + +/** + * A state to absorb data to be signed. + * + * After a signature has been computed or verified, the state remains valid for further operations. + * + * A subsequent signature would sign all the data accumulated since the creation of the state object. + */ +using __wasi_signature_state_t = int32_t; + +static_assert(sizeof(__wasi_signature_state_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_state_t) == 4, "witx calculated align"); + +/** + * A signature. + */ +using __wasi_signature_t = int32_t; + +static_assert(sizeof(__wasi_signature_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_t) == 4, "witx calculated align"); + +/** + * A public key, for key exchange and signature verification. + */ +using __wasi_publickey_t = int32_t; + +static_assert(sizeof(__wasi_publickey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_publickey_t) == 4, "witx calculated align"); + +/** + * A secret key, for key exchange mechanisms. + */ +using __wasi_secretkey_t = int32_t; + +static_assert(sizeof(__wasi_secretkey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_secretkey_t) == 4, "witx calculated align"); + +/** + * A state to absorb signed data to be verified. + */ +using __wasi_signature_verification_state_t = int32_t; + +static_assert(sizeof(__wasi_signature_verification_state_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_verification_state_t) == 4, "witx calculated align"); + +/** + * A state to perform symmetric operations. + * + * The state is not reset nor invalidated after an option has been performed. + * Incremental updates and sessions are thus supported. + */ +using __wasi_symmetric_state_t = int32_t; + +static_assert(sizeof(__wasi_symmetric_state_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_symmetric_state_t) == 4, "witx calculated align"); + +/** + * A symmetric key. + * + * The key can be imported from raw bytes, or can be a reference to a managed key. + * + * If it was imported, the host will wipe it from memory as soon as the handle is closed. + */ +using __wasi_symmetric_key_t = int32_t; + +static_assert(sizeof(__wasi_symmetric_key_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_symmetric_key_t) == 4, "witx calculated align"); + +/** + * An authentication tag. + * + * This is an object returned by functions computing authentication tags. + * + * A tag can be compared against another tag (directly supplied as raw bytes) in constant time with the `symmetric_tag_verify()` function. + * + * This object type can't be directly created from raw bytes. They are only returned by functions computing MACs. + * + * The host is responsible for securely wiping them from memory on close. + */ +using __wasi_symmetric_tag_t = int32_t; + +static_assert(sizeof(__wasi_symmetric_tag_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_symmetric_tag_t) == 4, "witx calculated align"); + +/** + * Options index, only required by the Interface Types translation layer. + */ +enum __wasi_opt_options_u_e_t : uint8_t { + __WASI_OPT_OPTIONS_U_SOME = 0, + + __WASI_OPT_OPTIONS_U_NONE = 1, + +}; +static_assert(sizeof(__wasi_opt_options_u_e_t) == 1, "witx calculated size"); +static_assert(alignof(__wasi_opt_options_u_e_t) == 1, "witx calculated align"); + +/** + * An optional options set. + * + * This union simulates an `Option` type to make the `options` parameter of some functions optional. + */ +union __wasi_opt_options_u_t { + __wasi_options_t some; +}; +struct __wasi_opt_options_t { + __wasi_opt_options_u_e_t tag; + __wasi_opt_options_u_t u; +}; + +static_assert(sizeof(__wasi_opt_options_t) == 8, "witx calculated size"); +static_assert(alignof(__wasi_opt_options_t) == 4, "witx calculated align"); +static_assert(offsetof(__wasi_opt_options_t, u) == 4, "witx calculated union offset"); + +/** + * Symmetric key index, only required by the Interface Types translation layer. + */ +enum __wasi_opt_symmetric_key_u_e_t : uint8_t { + __WASI_OPT_SYMMETRIC_KEY_U_SOME = 0, + + __WASI_OPT_SYMMETRIC_KEY_U_NONE = 1, + +}; +static_assert(sizeof(__wasi_opt_symmetric_key_u_e_t) == 1, "witx calculated size"); +static_assert(alignof(__wasi_opt_symmetric_key_u_e_t) == 1, "witx calculated align"); + +/** + * An optional symmetric key. + * + * This union simulates an `Option` type to make the `symmetric_key` parameter of some functions optional. + */ +union __wasi_opt_symmetric_key_u_t { + __wasi_symmetric_key_t some; +}; +struct __wasi_opt_symmetric_key_t { + __wasi_opt_symmetric_key_u_e_t tag; + __wasi_opt_symmetric_key_u_t u; +}; + +static_assert(sizeof(__wasi_opt_symmetric_key_t) == 8, "witx calculated size"); +static_assert(alignof(__wasi_opt_symmetric_key_t) == 4, "witx calculated align"); +static_assert(offsetof(__wasi_opt_symmetric_key_t, u) == 4, "witx calculated union offset"); + +using __wasi_u64_t = uint64_t; + +static_assert(sizeof(__wasi_u64_t) == 8, "witx calculated size"); +static_assert(alignof(__wasi_u64_t) == 8, "witx calculated align"); + +/** + * `$kx_keypair` is just an alias for `$keypair` + * + * However, bindings may want to define a specialized type `kx_keypair` as a super class of `keypair`. + */ +using __wasi_kx_keypair_t = __wasi_keypair_t; + +static_assert(sizeof(__wasi_kx_keypair_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_kx_keypair_t) == 4, "witx calculated align"); + +/** + * `$kx_publickey` is just an alias for `$publickey` + * + * However, bindings may want to define a specialized type `kx_publickey` as a super class of `publickey`, with additional methods such as `dh`. + */ +using __wasi_kx_publickey_t = __wasi_publickey_t; + +static_assert(sizeof(__wasi_kx_publickey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_kx_publickey_t) == 4, "witx calculated align"); + +/** + * `$kx_secretkey` is just an alias for `$secretkey` + * + * However, bindings may want to define a specialized type `kx_secretkey` as a super class of `secretkeykey`, with additional methods such as `dh`. + */ +using __wasi_kx_secretkey_t = __wasi_secretkey_t; + +static_assert(sizeof(__wasi_kx_secretkey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_kx_secretkey_t) == 4, "witx calculated align"); + +/** + * `$signature_keypair` is just an alias for `$keypair` + * + * However, bindings may want to define a specialized type `signature_keypair` as a super class of `keypair`, with additional methods such as `sign`. + */ +using __wasi_signature_keypair_t = __wasi_keypair_t; + +static_assert(sizeof(__wasi_signature_keypair_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_keypair_t) == 4, "witx calculated align"); + +/** + * `$signature_publickey` is just an alias for `$publickey` + * + * However, bindings may want to define a specialized type `signature_publickey` as a super class of `publickey`, with additional methods such as `verify`. + */ +using __wasi_signature_publickey_t = __wasi_publickey_t; + +static_assert(sizeof(__wasi_signature_publickey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_publickey_t) == 4, "witx calculated align"); + +/** + * `$signature_secretkey` is just an alias for `$secretkey` + * + * However, bindings may want to define a specialized type `signature_secretkey` as a super class of `secretkey`. + */ +using __wasi_signature_secretkey_t = __wasi_secretkey_t; + +static_assert(sizeof(__wasi_signature_secretkey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_secretkey_t) == 4, "witx calculated align"); diff --git a/utils/build_libpiper.sh b/utils/build_libpiper.sh new file mode 100644 index 00000000..05f4ab50 --- /dev/null +++ b/utils/build_libpiper.sh @@ -0,0 +1,18 @@ +#!/bin/bash +set -e + +echo "::group::Build and install libpiper" +rm -rf piper-source + +git clone https://github.com/OHF-Voice/piper1-gpl piper-source +cd piper-source/libpiper +git checkout 32b95f8c1f0dc0ce27a6acd1143de331f61af777 +cmake -Bbuild-deps \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX="$PWD/install" \ + -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + +cmake --build build-deps +cmake --install build-deps +echo "::endgroup::" diff --git a/utils/docker/Dockerfile.alpine-base b/utils/docker/Dockerfile.alpine-base new file mode 100644 index 00000000..c1e587ec --- /dev/null +++ b/utils/docker/Dockerfile.alpine-base @@ -0,0 +1,62 @@ +# syntax=docker/dockerfile:1.5-labs + +ARG ALPINE_VERSION=3.23 + +# ---------------------------------------------------- +# STAGE 1: Builder +# Compiles LLD from source, linking against system LLVM +# ---------------------------------------------------- +FROM alpine:${ALPINE_VERSION} AS builder + +# Use a specific tag for reproducibility +# Pinned version to match Alpine 3.23's system LLVM (v19) +ARG LLVM_VERSION=llvmorg-19.1.5 + +# Build Dependencies: +# - build-base: Compiler toolchain (gcc, g++, make, libc-dev) required for compilation. +# - cmake: The build system used by LLVM. +# - ninja: The build generator (faster than make). +# - python3: Required by LLVM build scripts and configuration utilities. +# - git: Required to clone the LLVM repository. +# - linux-headers: Ensures system header compatibility for C++ standard library builds. +# - llvm19-dev: Provides LLVM development headers for version 19. +# - llvm19-static: Provides static LLVM libraries for version 19. +# - zstd-static: Provides static zstd compression libraries. +# - zlib-static: Provides static zlib compression libraries. +RUN apk add --no-cache build-base cmake ninja python3 git linux-headers \ + llvm19-dev llvm19-static zstd-static zlib-static + + +# The system CMake files require testing libraries that are missing in the Alpine package. +# We create empty dummy archives to satisfy the dependency check. +RUN LLVM_LIB_DIR=/usr/lib/llvm19/lib && \ + for f in libLLVMTestingAnnotations.a libLLVMTestingSupport.a libllvm_gtest.a libllvm_gtest_main.a; do \ + ar rc "$LLVM_LIB_DIR/$f"; \ + done + +# Clone LLVM Project +WORKDIR /src +# Shallow clone specific tag +RUN git clone --depth 1 -b ${LLVM_VERSION} https://github.com/llvm/llvm-project.git + + +# Configure and Build LLD +# We treat LLD as a standalone project and point it to the system LLVM +WORKDIR /src/llvm-project/lld +RUN cmake -B build -G Ninja \ + -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DCMAKE_INSTALL_PREFIX=/usr/local \ + -DLLVM_DIR=/usr/lib/llvm19/lib/cmake/llvm \ + -DLLVM_BUILD_STATIC=ON \ + -DLLVM_INCLUDE_TESTS=OFF \ + -DLLVM_INCLUDE_EXAMPLES=OFF && \ + cmake --build build --target lldELF lldCommon lldMachO lldCOFF lldWasm lldMinGW + +# ---------------------------------------------------- +# STAGE 2: Final Image (Contains only the compiled static libraries) +# ---------------------------------------------------- +FROM alpine:${ALPINE_VERSION} + +# Copy artifacts to a standard library path +COPY --from=builder /src/llvm-project/lld/build/lib/liblld*.a /usr/local/lib/ + diff --git a/utils/docker/Dockerfile.alpine-static b/utils/docker/Dockerfile.alpine-static new file mode 100644 index 00000000..37e144df --- /dev/null +++ b/utils/docker/Dockerfile.alpine-static @@ -0,0 +1,116 @@ +# syntax=docker/dockerfile:1.5-labs + +ARG XX_VERSION=1.2.1 +ARG ALPINE_VERSION=3.23 + + +# ---------------------------------------------------- +# STAGE 1: Import Pre-Built LLD Libraries +# ---------------------------------------------------- +FROM wasmedge/wasmedge:alpine-base-3.23 AS lld-artifact + +# ---------------------------------------------------- +# STAGE 2: Main WasmEdge Build +# ---------------------------------------------------- +FROM --platform=$BUILDPLATFORM tonistiigi/xx:${XX_VERSION} AS xx +FROM --platform=$BUILDPLATFORM alpine:${ALPINE_VERSION} AS base +COPY --from=xx / / + +# Install host dependencies +RUN apk add bash cmake samurai g++ clang +SHELL [ "bash", "-c" ] + +# Make a cmake toolchain file +RUN cat <<'EOT' > /usr/bin/xx-toolchain && chmod a+x /usr/bin/xx-toolchain +#!/bin/bash +mkdir -p /etc/xx-toolchains/ +TOOLCHAIN_FILE="/etc/xx-toolchains/$(xx-clang --print-target-triple).cmake" +[ -f "$TOOLCHAIN_FILE" ] || cat < "$TOOLCHAIN_FILE" +set(CMAKE_CROSSCOMPILING ON) +set(CMAKE_SYSROOT "$(xx-info sysroot)") +set(CMAKE_SYSTEM_NAME "Linux") +set(CMAKE_SYSTEM_VERSION 1) +set(CMAKE_SYSTEM_PROCESSOR "$(xx-info march)") +set(CMAKE_C_COMPILER "$(which clang)") +set(CMAKE_CXX_COMPILER "$(which clang++)") +set(CMAKE_ASM_COMPILER "$(which clang)") +set(CMAKE_AR "$(which ar)") +set(CMAKE_RANLIB "$(which ranlib)") +set(PKG_CONFIG_EXECUTABLE "$(xx-clang --print-prog-name=pkg-config)") +set(CMAKE_C_COMPILER_TARGET "$(xx-clang --print-target-triple)") +set(CMAKE_CXX_COMPILER_TARGET "$(xx-clang --print-target-triple)") +set(CMAKE_ASM_COMPILER_TARGET "$(xx-clang --print-target-triple)") +EOF +echo "$TOOLCHAIN_FILE" +EOT + +FROM base AS config +ARG TARGETPLATFORM + +RUN apk add git llvm-dev +# We add zstd-static because the SYSTEM LLVM (installed via apk) depends on it. +RUN xx-apk add \ + g++ \ + llvm-dev llvm-static \ + lld lld-dev \ + zlib-dev zlib-static \ + zstd-static + +# Fix LLVMConfig.cmake to account for the sysroot +RUN sed -i 's|/usr/lib/llvm|${CMAKE_SYSROOT}usr/lib/llvm|' $(xx-info sysroot)usr/lib/cmake/llvm*/LLVMConfig.cmake + +# Patch: Create dummy files for missing testing libraries to satisfy CMake +RUN for d in $(xx-info sysroot)usr/lib/llvm*/lib; do \ + for f in libLLVMTestingAnnotations.a libLLVMTestingSupport.a libllvm_gtest.a libllvm_gtest_main.a; do \ + ar rc "$d/$f"; \ + done \ + done + +# Inject the Native LLD libraries from Stage 1 (The Base Image) +# Note: We changed the source path to /usr/local/lib because that's where we put them in the base image +COPY --from=lld-artifact /usr/local/lib/liblld*.a /tmp/lld-libs/ +RUN cp /tmp/lld-libs/*.a $(xx-info sysroot)usr/lib/llvm*/lib/ + +# Use the native version of llvm-config +RUN ! xx-info is-cross || cp -f /usr/lib/llvm*/bin/llvm-config $(xx-info sysroot)usr/lib/llvm*/bin/llvm-config + +RUN --mount=type=bind,target=/src,source=. \ + cmake -S /src -B /build -G Ninja \ + -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_INSTALL_PREFIX="/install" \ + # For cross compiling + -DCMAKE_TOOLCHAIN_FILE="$(xx-toolchain)" \ + -DWASMEDGE_BUILD_PACKAGE="TGZ" \ + -DWASMEDGE_USE_LLVM=ON \ + # Build just what we need + -DWASMEDGE_BUILD_STATIC_LIB=ON \ + -DWASMEDGE_BUILD_TESTS=OFF \ + -DWASMEDGE_BUILD_SHARED_LIB=OFF \ + -DWASMEDGE_BUILD_TOOLS=OFF \ + -DWASMEDGE_BUILD_PLUGINS=OFF \ + -DWASMEDGE_BUILD_EXAMPLE=OFF \ + # link llvm statically + -DWASMEDGE_LINK_LLVM_STATIC=ON \ + -DWASMEDGE_LINK_TOOLS_STATIC=ON \ + # Disable extra dependencies + -DWASMEDGE_DISABLE_LIBTINFO=ON + +FROM config AS build +RUN --mount=type=bind,target=/src,source=. \ + cmake --build /build -- install package + +ARG ALPINE_VERSION +RUN --mount=type=bind,target=/src,source=. < /etc/apt/trusted.gpg.d/apt.llvm.org.asc + add-apt-repository "deb http://apt.llvm.org/$(lsb_release -sc)/ llvm-toolchain-$(lsb_release -sc)-${LLVM_VERSION} main" + apt-get update -y +EOT + +FROM base as deps +ARG TARGETPLATFORM LLVM_VERSION + +# Install llvm-*-dev: +RUN /bin/bash < package depends on llvm-*: and llvm-*-tools: but + # should depend on llvm-*-tools: and llvm-*:. + # Patch llvm-*-dev: to replace those dependencies. + # See: https://groups.google.com/g/linux.debian.bugs.dist/c/OrHgd5vY278 + cd $(mktemp -d) + xx-apt show llvm-${LLVM_VERSION}-dev + apt-get download llvm-${LLVM_VERSION}-dev:$(xx-info debian-arch) + ar x llvm-${LLVM_VERSION}-dev_*.deb + tar xJf control.tar.* + sed -Ei 's|(llvm-[0-9]*(-tools)?) |\1:'$(TARGETPLATFORM='' TARGETPAIR='' TARGETOS='' TARGETARCH='' TARGETVARIANT='' xx-info debian-arch)' |g' control + tar --ignore-failed-read -czf control.tar.gz {post,pre}{inst,rm} md5sums control + ar rcs llvm-dev-patched.deb debian-binary control.tar.gz data.tar.* + apt-get install --no-install-recommends -y ./llvm-dev-patched.deb +EOT + +# Install other *: dev dependencies +RUN xx-apt-get install --no-install-recommends -y \ + xx-cxx-essentials \ + liblld-${LLVM_VERSION}-dev libpolly-${LLVM_VERSION}-dev \ + libncurses5-dev zlib1g-dev + +# Make a cmake toolchain file +RUN cat < /toolchain.cmake + set(CMAKE_SYSTEM_NAME "Linux") + set(CMAKE_SYSTEM_VERSION 1) + set(CMAKE_SYSTEM_PROCESSOR "$(xx-info march)") + set(CMAKE_C_COMPILER "xx-clang") + set(CMAKE_CXX_COMPILER "xx-clang++") + set(CMAKE_ASM_COMPILER "xx-clang") + set(CMAKE_AR "ar") + set(DPKG_CONFIG_EXECUTABLE $(xx-clang --print-prog-name=pkg-config)) + set(DCMAKE_C_COMPILER_TARGET $(xx-clang --print-target-triple)) + set(DCMAKE_CXX_COMPILER_TARGET $(xx-clang --print-target-triple)) + set(DCMAKE_ASM_COMPILER_TARGET $(xx-clang --print-target-triple)) +EOT + +FROM deps as src +ADD . /src + +FROM src as build +RUN cmake -S /src -B /build -G Ninja \ + -DCMAKE_TOOLCHAIN_FILE=/toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_INSTALL_PREFIX=/install \ + -DWASMEDGE_BUILD_PACKAGE="TGZ" \ + -DWASMEDGE_BUILD_TESTS=OFF \ + -DWASMEDGE_BUILD_SHARED_LIB=OFF \ + -DWASMEDGE_BUILD_STATIC_LIB=ON \ + -DWASMEDGE_BUILD_TOOLS=OFF \ + -DWASMEDGE_BUILD_PLUGINS=OFF \ + -DWASMEDGE_BUILD_EXAMPLE=OFF \ + -DWASMEDGE_LINK_LLVM_STATIC=ON \ + -DWASMEDGE_LINK_TOOLS_STATIC=ON \ + -DWASMEDGE_USE_LLVM=ON + +RUN cmake --build /build -- install +RUN cmake --build /build -- package +RUN mv /build/WasmEdge-*-Linux.tar.gz $(echo build/WasmEdge-*-Linux.tar.gz | sed 's|WasmEdge-\(.*\)-Linux.tar.gz|WasmEdge-\1-debian'$(lsb_release -sr)'_'$(xx-info march)'_static.tar.gz|') + +FROM scratch as tar +COPY --from=build /build/WasmEdge-*.tar.gz / diff --git a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps new file mode 100644 index 00000000..aa479a0d --- /dev/null +++ b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps @@ -0,0 +1,38 @@ +ARG BASE=wasmedge/wasmedge:manylinux2014_x86_64 +FROM ${BASE} + +ENV PATH /opt/rh/devtoolset-11/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH /opt/rh/devtoolset-11/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH /opt/rh/devtoolset-11/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH /opt/rh/devtoolset-11/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} + +RUN cd && (yum check-update || true) && \ + yum install -y cmake wget unzip zlib-devel zlib-static +RUN yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ + yum install -y gh + +WORKDIR /root + +COPY opencvmini/install-opencvmini.sh . +ENV OPENCV_VERSION "4.8.0" +RUN [ "/bin/bash", "install-opencvmini.sh" ] + +COPY wasi-nn/install-pytorch.sh . +ENV PYTORCH_VERSION "2.5.1" +ENV PYTORCH_INSTALL_TO "/root" +ENV Torch_DIR "/root/libtorch" +RUN [ "/bin/bash", "install-pytorch.sh", "--disable-cxx11-abi" ] + +COPY wasi-crypto/build-openssl.sh . +ENV OPENSSL_ROOT_DIR "/root/openssl-1.1.1n/openssl" +RUN [ "/bin/bash", "build-openssl.sh" ] + +COPY ffmpeg/install-ffmpeg-v7.1.1.sh . +RUN [ "/bin/bash", "install-ffmpeg-v7.1.1.sh" ] +ENV PKG_CONFIG_PATH /root/FFmpeg-n7.1.1/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH /root/FFmpeg-n7.1.1/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} + +ENV OPENVINO_VERSION "2025.0.0" +ENV OPENVINO_YEAR "2025" + +RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 new file mode 100644 index 00000000..30a3e87a --- /dev/null +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +FROM quay.io/pypa/manylinux2014_aarch64 + +MAINTAINER hydai hydai@secondstate.io + +ADD SHA256SUM.manylinux2014 /root/ + +ENV PATH /opt/rh/devtoolset-10/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH /opt/rh/devtoolset-10/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH /opt/rh/devtoolset-10/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH /opt/rh/devtoolset-10/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} + +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build centos-release-scl && \ + yum install -y devtoolset-10 && \ + export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ + 'import multiprocessing; print(multiprocessing.cpu_count())') && \ + export CFGFLAGS="--prefix=/opt/rh/devtoolset-10/root/usr --disable-shared --libdir=/opt/rh/devtoolset-10/root/usr/lib64" && \ + curl -s -L -O --remote-name-all \ + https://github.com/facebook/zstd/releases/download/v1.5.6/zstd-1.5.6.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-v3.31.2.tar.gz \ + https://github.com/ninja-build/ninja/archive/refs/tags/v1.12.1.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/llvm-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/lld-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/libunwind-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/cmake-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/third-party-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/clang-19.1.5.src.tar.xz && \ + sha256sum -c SHA256SUM.manylinux2014 && \ + gzip -dc zstd-1.5.6.tar.gz | tar -xf - && \ + gzip -dc cmake-v3.31.2.tar.gz | tar -xf - && \ + gzip -dc v1.12.1.tar.gz | tar -xf - && \ + xz -dc llvm-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc lld-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc libunwind-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc cmake-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc third-party-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc clang-19.1.5.src.tar.xz | tar -xf - && \ + export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-10/root/usr LIBDIR=/opt/rh/devtoolset-10/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ + cd zstd-1.5.6 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-10/root/usr/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ + ../ninja-1.12.1/configure.py --bootstrap \ + --with-python=/opt/python/cp311-cp311/bin/python && \ + cp -v ninja /opt/rh/devtoolset-10/root/usr/bin/ninja && cd - && rm -rf build && \ + mkdir build && cd build && ../cmake-v3.31.2/configure --prefix=/opt/rh/devtoolset-10/root/usr \ + --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mv -v llvm-19.1.5.src llvm && \ + mv -v lld-19.1.5.src lld && \ + mv -v libunwind-19.1.5.src libunwind && \ + mv -v cmake-19.1.5.src cmake && \ + mv -v third-party-19.1.5.src third-party && \ + mv -v clang-19.1.5.src clang && \ + cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-10/root/usr \ + -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ + -DLLVM_TARGETS_TO_BUILD="AArch64;BPF" -DLLVM_ENABLE_PROJECTS="lld;clang" \ + -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" \ + -DBUILD_SHARED_LIBS=OFF llvm && \ + cmake --build build --target install && \ + rm -rf build && rm -rf * + +RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 new file mode 100644 index 00000000..6ba01f13 --- /dev/null +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +FROM quay.io/pypa/manylinux2014_x86_64 + +MAINTAINER hydai hydai@secondstate.io + +ADD SHA256SUM.manylinux2014 /root/ + +ENV PATH /opt/rh/devtoolset-11/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH /opt/rh/devtoolset-11/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH /opt/rh/devtoolset-11/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH /opt/rh/devtoolset-11/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} + +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg centos-release-scl && \ + yum install -y devtoolset-11 && \ + export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ + 'import multiprocessing; print(multiprocessing.cpu_count())') && \ + export CFGFLAGS="--prefix=/opt/rh/devtoolset-11/root/usr --disable-shared --libdir=/opt/rh/devtoolset-11/root/usr/lib64" && \ + curl -s -L -O --remote-name-all \ + https://github.com/facebook/zstd/releases/download/v1.5.6/zstd-1.5.6.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-v3.31.2.tar.gz \ + https://github.com/ninja-build/ninja/archive/refs/tags/v1.12.1.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/llvm-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/lld-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/libunwind-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/cmake-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/third-party-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/clang-19.1.5.src.tar.xz && \ + sha256sum -c SHA256SUM.manylinux2014 && \ + gzip -dc zstd-1.5.6.tar.gz | tar -xf - && \ + gzip -dc cmake-v3.31.2.tar.gz | tar -xf - && \ + gzip -dc v1.12.1.tar.gz | tar -xf - && \ + xz -dc llvm-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc lld-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc libunwind-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc cmake-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc third-party-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc clang-19.1.5.src.tar.xz | tar -xf - && \ + export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-11/root/usr LIBDIR=/opt/rh/devtoolset-11/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ + cd zstd-1.5.6 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-11/root/usr/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ + ../ninja-1.12.1/configure.py --bootstrap \ + --with-python=/opt/python/cp311-cp311/bin/python && \ + cp -v ninja /opt/rh/devtoolset-11/root/usr/bin/ninja && cd - && rm -rf build && \ + mkdir build && cd build && ../cmake-v3.31.2/configure --prefix=/opt/rh/devtoolset-11/root/usr \ + --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mv -v llvm-19.1.5.src llvm && \ + mv -v lld-19.1.5.src lld && \ + mv -v libunwind-19.1.5.src libunwind && \ + mv -v cmake-19.1.5.src cmake && \ + mv -v third-party-19.1.5.src third-party && \ + mv -v clang-19.1.5.src clang && \ + cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-11/root/usr \ + -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ + -DLLVM_TARGETS_TO_BUILD="X86;BPF" -DLLVM_ENABLE_PROJECTS="lld;clang" \ + -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ + -DBUILD_SHARED_LIBS=OFF llvm && \ + cmake --build build --target install && \ + rm -rf build && rm -rf * + +RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux_2_28-base b/utils/docker/Dockerfile.manylinux_2_28-base new file mode 100644 index 00000000..da7b1449 --- /dev/null +++ b/utils/docker/Dockerfile.manylinux_2_28-base @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC +ARG BASE_IMAGE="quay.io/pypa/manylinux_2_28_x86_64" +FROM ${BASE_IMAGE} + +ADD SHA256SUM.manylinux_2_28 /root/ + +# See /opt/rh/gcc-toolset-13/enable +ENV PATH=/opt/rh/gcc-toolset-13/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH=/opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH=/opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-13/root/usr/lib64:/opt/rh/gcc-toolset-13/root/usr/lib:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +ENV PKG_CONFIG_PATH=/opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} + +ARG LLVM_TARGETS LLVM_TRIPLE + +RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake yum-utils && \ + yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ + yum install -y gcc-toolset-13 && \ + export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ + 'import multiprocessing; print(multiprocessing.cpu_count())') && \ + export CFGFLAGS="--prefix=/opt/rh/gcc-toolset-13/root/usr --disable-shared --libdir=/opt/rh/gcc-toolset-13/root/usr/lib64" && \ + curl -s -L -O --remote-name-all \ + https://github.com/ninja-build/ninja/archive/refs/tags/v1.12.1.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/llvm-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/lld-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/libunwind-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/cmake-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/third-party-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/clang-19.1.5.src.tar.xz && \ + sha256sum -c SHA256SUM.manylinux_2_28 && \ + gzip -dc v1.12.1.tar.gz | tar -xf - && \ + xz -dc llvm-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc lld-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc libunwind-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc cmake-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc third-party-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc clang-19.1.5.src.tar.xz | tar -xf - && \ + mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ + ../ninja-1.12.1/configure.py --bootstrap \ + --with-python=/opt/python/cp311-cp311/bin/python && \ + cp -v ninja /opt/rh/gcc-toolset-13/root/usr/bin/ninja && cd - && rm -rf build && \ + mv -v llvm-19.1.5.src llvm && \ + mv -v lld-19.1.5.src lld && \ + mv -v libunwind-19.1.5.src libunwind && \ + mv -v cmake-19.1.5.src cmake && \ + mv -v third-party-19.1.5.src third-party && \ + mv -v clang-19.1.5.src clang && \ + cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/rh/gcc-toolset-13/root/usr \ + -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ + -DLLVM_TARGETS_TO_BUILD="${LLVM_TARGETS}" -DLLVM_ENABLE_PROJECTS="lld;clang" \ + -DLLVM_DEFAULT_TARGET_TRIPLE="${LLVM_TRIPLE}" \ + -DBUILD_SHARED_LIBS=OFF llvm && \ + cmake --build build --target install && \ + rm -rf build && rm -rf * + +RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps new file mode 100644 index 00000000..b08f89de --- /dev/null +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -0,0 +1,48 @@ +ARG BASE_IMAGE="wasmedge/wasmedge:manylinux_2_28_x86_64" +FROM ${BASE_IMAGE} AS base + +WORKDIR /root + +### deps for x86_64 ### +FROM base AS deps-amd64 +RUN cd && (yum check-update || true) && \ + yum install -y wget unzip zlib-devel zlib-static elfutils-libelf-devel + +COPY wasi-nn/install-pytorch.sh . +ENV PYTORCH_VERSION="2.5.1" +ENV PYTORCH_INSTALL_TO="/root" +ENV Torch_DIR="/root/libtorch" +RUN [ "/bin/bash", "install-pytorch.sh", "--disable-cxx11-abi" ] + +### deps for aarch64 ### +FROM base AS deps-arm64 +RUN cd && (yum check-update || true) && \ + yum install -y wget unzip zlib-devel zlib-static + +### deps for all ### +FROM deps-${TARGETARCH} AS final + +COPY opencvmini/install-opencvmini.sh . +ENV OPENCV_VERSION="4.8.0" +RUN [ "/bin/bash", "install-opencvmini.sh" ] + +COPY wasi-crypto/build-openssl.sh . +ENV OpenSSL_DIR="/root/openssl-1.1.1n/openssl" +RUN [ "/bin/bash", "build-openssl.sh" ] + +COPY ffmpeg/install-ffmpeg-v7.1.1.sh . +RUN [ "/bin/bash", "install-ffmpeg-v7.1.1.sh" ] +ENV PKG_CONFIG_PATH=/root/FFmpeg-n7.1.1/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH=/root/FFmpeg-n7.1.1/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} + +ENV OPENVINO_VERSION="2025.0.0" +ENV OPENVINO_YEAR="2025" + +COPY wasi-nn/install-onnxruntime.sh . +RUN [ "/bin/bash", "install-onnxruntime.sh" ] + +COPY wasi-nn/libpiper.patch . +COPY wasi-nn/install-libpiper.sh . +RUN [ "/bin/bash", "install-libpiper.sh" ] + +RUN yum clean all diff --git a/utils/docker/Dockerfile.release b/utils/docker/Dockerfile.release new file mode 100644 index 00000000..ab05caf5 --- /dev/null +++ b/utils/docker/Dockerfile.release @@ -0,0 +1,13 @@ +FROM ubuntu:22.04 +ARG VERSION + +RUN apt-get update && apt-get install -y netbase +ADD WasmEdge-$VERSION-Linux.tar.gz /tmp/ +RUN cp -rf /tmp/WasmEdge-$VERSION-Linux/bin/* /usr/local/bin && \ + cp -rf /tmp/WasmEdge-$VERSION-Linux/lib64/* /usr/local/lib && \ + cp -rf /tmp/WasmEdge-$VERSION-Linux/include/* /usr/local/include && \ + ldconfig /usr/local/lib +RUN rm -rf /tmp/WasmEdge-$VERSION-Linux + +WORKDIR /app +CMD ["/usr/local/bin/wasmedge"] diff --git a/utils/docker/Dockerfile.ubuntu-base b/utils/docker/Dockerfile.ubuntu-base new file mode 100644 index 00000000..4d4cb463 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu-base @@ -0,0 +1,79 @@ +ARG UBUNTU_VER=22 +FROM ubuntu:${UBUNTU_VER}.04 AS base + +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y \ + curl \ + dpkg-dev \ + g++ \ + gcc \ + git \ + ninja-build \ + software-properties-common \ + wget \ + zlib1g-dev + +### deps for ubuntu 20.04 ### +FROM base AS deps-20 + +RUN curl -sSf https://apt.kitware.com/kitware-archive.sh | sh +RUN apt-get install -y cmake + +RUN wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - \ + && echo "deb http://apt.llvm.org/focal llvm-toolchain-focal-18 main" | tee /etc/apt/sources.list.d/llvm.list + +RUN apt-get install -y \ + llvm-18-dev \ + liblld-18-dev \ + libpolly-18-dev \ + clang-18 + +RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-18 100 && \ + update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-18 100 && \ + update-alternatives --install /usr/bin/llvm-strip llvm-strip /usr/bin/llvm-strip-18 100 + +ENV CC=/usr/bin/clang-18 +ENV CXX=/usr/bin/clang++-18 + +### deps for ubuntu 22.04 ### +FROM base AS deps-22 + +RUN curl -sSf https://apt.kitware.com/kitware-archive.sh | sh +RUN apt-get install -y cmake + +RUN apt-get install -y \ + llvm-15-dev \ + liblld-15-dev \ + clang-15 + +RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-15 100 && \ + update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-15 100 && \ + update-alternatives --install /usr/bin/llvm-strip llvm-strip /usr/bin/llvm-strip-15 100 + +ENV CC=/usr/bin/clang-15 +ENV CXX=/usr/bin/clang++-15 + +### deps for ubuntu 24.04 ### +FROM base AS deps-24 + +RUN apt-get install -y cmake + +RUN apt-get install -y \ + llvm-18-dev \ + liblld-18-dev \ + libpolly-18-dev \ + clang-18 + +RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-18 100 && \ + update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-18 100 && \ + update-alternatives --install /usr/bin/llvm-strip llvm-strip /usr/bin/llvm-strip-18 100 + +ENV CC=/usr/bin/clang-18 +ENV CXX=/usr/bin/clang++-18 + +### cleanup +FROM deps-${UBUNTU_VER} AS clean-apt + +RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.ubuntu-cuda b/utils/docker/Dockerfile.ubuntu-cuda new file mode 100644 index 00000000..f10b8b5f --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu-cuda @@ -0,0 +1,24 @@ +ARG BASE_IMAGE=wasmedge/wasmedge:latest +FROM ${BASE_IMAGE} AS base + +WORKDIR /root + +ARG CUDA_KEYRING=cuda-keyring_1.1-1_all.deb +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/${CUDA_KEYRING} && \ + dpkg -i ${CUDA_KEYRING} && \ + rm -f ${CUDA_KEYRING} + +ARG NVCC_VER=12-0 +RUN apt-get update && \ + apt-get install -y \ + cuda-nvcc-${NVCC_VER} \ + libcublas-dev-${NVCC_VER} \ + pkg-config \ + unzip + +ENV CXXFLAGS="-Wno-error" + +### cleanup +FROM base AS clean-apt + +RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.ubuntu-env b/utils/docker/Dockerfile.ubuntu-env new file mode 100644 index 00000000..dde599e0 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu-env @@ -0,0 +1,15 @@ +ARG BASE_IMAGE=wasmedge/wasmedge:latest +ARG TOOLCHAIN=clang +FROM ${BASE_IMAGE} AS base + +### env for clang +FROM base AS deps-clang + +### env for gcc +FROM base AS deps-gcc + +ENV CC=/usr/bin/gcc +ENV CXX=/usr/bin/g++ + +### final +FROM deps-${TOOLCHAIN} AS final diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps new file mode 100644 index 00000000..a85945d3 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -0,0 +1,98 @@ +ARG BASE_IMAGE=wasmedge/wasmedge:latest +ARG UBUNTU_VER=20 +FROM ${BASE_IMAGE} AS base + +RUN apt-get update && \ + apt-get install -y \ + cargo \ + libelf-dev \ + libomp-dev \ + libopenblas-dev \ + libssl-dev \ + pkg-config \ + unzip \ + yasm + +RUN apt-get install -y \ + libgrpc++-dev \ + libgrpc-dev \ + protobuf-compiler-grpc + +# FFmpeg 6.1 (ubuntu 24.04) +FROM base AS deps-24 + +RUN apt-get install -y \ + libavcodec-dev \ + libavdevice-dev \ + libavfilter-dev \ + libavformat-dev \ + libavutil-dev \ + libswresample-dev \ + libswscale-dev + +# FFmpeg 7.1.1 (ubuntu 20.04, 22.04) +FROM base AS deps-20 + +WORKDIR /usr/local + +COPY ffmpeg/install-ffmpeg-v7.1.1.sh . +RUN [ "/bin/bash", "install-ffmpeg-v7.1.1.sh" ] +ENV PKG_CONFIG_PATH=/usr/local/FFmpeg-n7.1.1/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH=/usr/local/FFmpeg-n7.1.1/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} + +FROM deps-20 AS deps-22 + +# Other dependencies +FROM deps-${UBUNTU_VER} AS deps-all + +WORKDIR /root + +COPY opencvmini/install-opencvmini.sh . +ENV OPENCV_VERSION="4.8.0" +RUN [ "/bin/bash", "install-opencvmini.sh" ] + +COPY wasi-nn/install-pytorch.sh . +ENV PYTORCH_VERSION="2.5.1" +ENV PYTORCH_INSTALL_TO="/usr/local" +ENV Torch_DIR="/usr/local/libtorch" +RUN [ "/bin/bash", "install-pytorch.sh" ] + +ARG UBUNTU_VER + +COPY wasi-nn/install-openvino.sh . +COPY wasi-nn/install-openvino-genai.sh . +ENV OPENVINO_UBUNTU_VERSION=${UBUNTU_VER} +ENV OPENVINO_VERSION="2025.0.0" +ENV OPENVINO_YEAR="2025" +ENV OPENVINOGEN_VERSION="2025.0.0.0" +ENV OPENVINOGEN_YEAR="2025.0" +ENV OPENVINOGEN_DIRNAME="openvino_genai" +RUN [ "/bin/bash", "install-openvino.sh" ] +RUN [ "/bin/bash", "install-openvino-genai.sh" ] +RUN [ "ls", "-al" ] +RUN [ "/bin/bash", "-c", "echo \"source ./openvino_genai/setupvars.sh\" >> .bashrc" ] + +COPY wasi-nn/install-onnxruntime.sh . +RUN [ "/bin/bash", "install-onnxruntime.sh" ] + +COPY wasi-nn/libpiper.patch . +COPY wasi-nn/install-libpiper.sh . +RUN [ "/bin/bash", "install-libpiper.sh" ] + +COPY wasi-nn/install-chattts.sh . +RUN [ "/bin/bash", "install-chattts.sh" ] + +### cleanup +FROM deps-all AS clean-apt + +RUN rm -f \ + install-opencvmini.sh \ + install-ffmpeg-v7.1.1.sh \ + install-pytorch.sh \ + install-openvino.sh \ + install-onnxruntime.sh \ + install-openvino-genai.sh \ + install-chattts.sh \ + install-libpiper.sh + +RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.ubuntu2004_x86_64 b/utils/docker/Dockerfile.ubuntu2004_x86_64 new file mode 100644 index 00000000..f71ed3c8 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu2004_x86_64 @@ -0,0 +1,24 @@ +FROM ubuntu:20.04 + +MAINTAINER hydai hydai@secondstate.io +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && apt upgrade -y \ + && apt install -y \ + software-properties-common \ + wget \ + cmake \ + ninja-build \ + curl \ + git \ + dpkg-dev \ + llvm-12-dev \ + liblld-12-dev \ + gcc \ + rpm \ + g++ + +RUN rm -rf /var/lib/apt/lists/* + +ENV CC=gcc +ENV CXX=g++ diff --git a/utils/docker/Dockerfile.ubuntu2104_armv7l b/utils/docker/Dockerfile.ubuntu2104_armv7l new file mode 100644 index 00000000..f1439243 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu2104_armv7l @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +FROM arm32v7/ubuntu:hirsute + +MAINTAINER hydai hydai@secondstate.io +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && apt upgrade -y \ + && apt install -y \ + build-essential \ + cmake \ + curl \ + dpkg-dev \ + gcc \ + gcc-multilib \ + g++ \ + g++-multilib \ + git \ + llvm-12-dev \ + liblld-12-dev \ + libssl-dev \ + ninja-build \ + software-properties-common \ + python3 \ + rpm \ + wget \ + xz-utils + +# Build CMake from source to avoid compiler_id_detection failures when using QEMU user-mode emulation +# See: https://gitlab.kitware.com/cmake/cmake/-/issues/20568 + +#RUN wget https://github.com/Kitware/CMake/releases/download/v3.29.3/cmake-3.29.3.tar.gz --no-check-certificate && \ +# tar zxvf cmake-3.29.3.tar.gz && \ +# cd cmake-3.29.3 && \ +# ./configure && \ +# make install -j$(nproc) && \ +# cd .. && rm -rf cmake-3.29.3 + +RUN rm -rf /var/lib/apt/lists/* + +ENV CC=gcc +ENV CXX=g++ diff --git a/utils/docker/SHA256SUM.manylinux2014 b/utils/docker/SHA256SUM.manylinux2014 new file mode 100644 index 00000000..63aa74fb --- /dev/null +++ b/utils/docker/SHA256SUM.manylinux2014 @@ -0,0 +1,9 @@ +e7dfc8050407b5cc564c1c1afe19517255c9229cccd886dbd5bac9b652828d85 clang-19.1.5.src.tar.xz +a08ae477571fd5e929c27d3d0d28c6168d58dd00b6354c2de3266ae0d86ad44f cmake-19.1.5.src.tar.xz +0019dfc4b32d63c1392aa264aed2253c1e0c2fb09216f8e2cc269bbfb8bb49b5 cmake-v3.31.2.tar.gz +997b493fb604e5e2c5b11c765a4c42b37acf00a4d6e8a14f8108d5c1051d760f libunwind-19.1.5.src.tar.xz +f71835d49461a15c283aa9030a854abfd7c651a685d711a67158644b043f6f14 lld-19.1.5.src.tar.xz +7d71635948e4da1814ce8e15ec45399e4094a5442e86d352c96ded0f2b3171b6 llvm-19.1.5.src.tar.xz +22b352c35b034a4ab3f2b852b6a2602a4da8971abe459080450d9e3462550d1d third-party-19.1.5.src.tar.xz +821bdff48a3f683bc4bb3b6f0b5fe7b2d647cf65d52aeb63328c91a6c6df285a v1.12.1.tar.gz +8c29e06cf42aacc1eafc4077ae2ec6c6fcb96a626157e0593d5e82a34fd403c1 zstd-1.5.6.tar.gz diff --git a/utils/docker/SHA256SUM.manylinux_2_28 b/utils/docker/SHA256SUM.manylinux_2_28 new file mode 100644 index 00000000..136490f8 --- /dev/null +++ b/utils/docker/SHA256SUM.manylinux_2_28 @@ -0,0 +1,7 @@ +e7dfc8050407b5cc564c1c1afe19517255c9229cccd886dbd5bac9b652828d85 clang-19.1.5.src.tar.xz +a08ae477571fd5e929c27d3d0d28c6168d58dd00b6354c2de3266ae0d86ad44f cmake-19.1.5.src.tar.xz +997b493fb604e5e2c5b11c765a4c42b37acf00a4d6e8a14f8108d5c1051d760f libunwind-19.1.5.src.tar.xz +f71835d49461a15c283aa9030a854abfd7c651a685d711a67158644b043f6f14 lld-19.1.5.src.tar.xz +7d71635948e4da1814ce8e15ec45399e4094a5442e86d352c96ded0f2b3171b6 llvm-19.1.5.src.tar.xz +22b352c35b034a4ab3f2b852b6a2602a4da8971abe459080450d9e3462550d1d third-party-19.1.5.src.tar.xz +821bdff48a3f683bc4bb3b6f0b5fe7b2d647cf65d52aeb63328c91a6c6df285a v1.12.1.tar.gz diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh new file mode 100755 index 00000000..5e4affab --- /dev/null +++ b/utils/docker/build-manylinux.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +git config --global --add safe.directory $(pwd) + +CMAKE_BUILD_TYPE="Release" +IS_BUILD_TARGET=true +IS_NINJA=true +CMAKE_OPTS="" + +for i in "$@"; do + case $i in + --release|--Release) + CMAKE_BUILD_TYPE="Release" + shift + ;; + --debug|--Debug) + CMAKE_BUILD_TYPE="Debug" + shift + ;; + --not-build) + IS_BUILD_TARGET=false + shift + ;; + --not-ninja) + IS_NINJA=false + shift + ;; + *) + CMAKE_OPTS="${CMAKE_OPTS} $i" + shift + ;; + esac +done + +if $IS_NINJA; then + if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM" ${CMAKE_OPTS} .; then + echo === CMakeOutput.log === + cat build/CMakeFiles/CMakeOutput.log + echo === CMakeError.log === + cat build/CMakeFiles/CMakeError.log + exit 1 + fi +else + rm -rf build + mkdir build + cd build + if ! cmake -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" ${CMAKE_OPTS} ..; then + cd .. + echo === CMakeOutput.log === + cat build/CMakeFiles/CMakeOutput.log + echo === CMakeError.log === + cat build/CMakeFiles/CMakeError.log + exit 1 + fi + cd .. +fi + +if ${IS_BUILD_TARGET}; then + cmake --build build + cmake --build build --target package +fi diff --git a/utils/docker/docker-bake.alpine-base.hcl b/utils/docker/docker-bake.alpine-base.hcl new file mode 100644 index 00000000..e492c9ef --- /dev/null +++ b/utils/docker/docker-bake.alpine-base.hcl @@ -0,0 +1,12 @@ +group "default" { + targets = ["alpine-base"] +} + +target "alpine-base" { + dockerfile = "Dockerfile.alpine-base" + context = "utils/docker" + # Build for both major architectures at once + platforms = ["linux/amd64", "linux/arm64"] + # Tag it clearly so we can use it later + tags = ["wasmedge/wasmedge:alpine-base-3.23"] +} diff --git a/utils/docker/docker-bake.alpine-static.hcl b/utils/docker/docker-bake.alpine-static.hcl new file mode 100644 index 00000000..91938349 --- /dev/null +++ b/utils/docker/docker-bake.alpine-static.hcl @@ -0,0 +1,17 @@ +group "default" { + targets = ["cross"] +} + +target "base" { + dockerfile = "./utils/docker/Dockerfile.alpine-static" + context = "." + output = ["build"] +} + +target "cross" { + inherits = ["base"] + platforms = [ + "linux/amd64", + "linux/arm64" + ] +} diff --git a/utils/docker/docker-bake.ci-image-base.hcl b/utils/docker/docker-bake.ci-image-base.hcl new file mode 100644 index 00000000..b98e75b6 --- /dev/null +++ b/utils/docker/docker-bake.ci-image-base.hcl @@ -0,0 +1,26 @@ +group "default" { + targets = [ + "x86_64", + "aarch64" + ] +} + +target "base" { + dockerfile = "Dockerfile.ci-image-base" + context = "./utils/docker" +} + +target "x86_64" { + inherits = ["base"] + platforms = ["linux/amd64"] + tags = [ + "wasmedge/wasmedge:ci-image-base", + "wasmedge/wasmedge:ci-image-base_x86_64" + ] +} + +target "aarch64" { + inherits = ["base"] + platforms = ["linux/arm64"] + tags = ["wasmedge/wasmedge:ci-image-base_aarch64"] +} diff --git a/utils/docker/docker-bake.debian-static.hcl b/utils/docker/docker-bake.debian-static.hcl new file mode 100644 index 00000000..5558e95c --- /dev/null +++ b/utils/docker/docker-bake.debian-static.hcl @@ -0,0 +1,17 @@ +group "default" { + targets = ["cross"] +} + +target "base" { + dockerfile = "./utils/docker/Dockerfile.debian-static" + context = "." + output = ["build"] +} + +target "cross" { + inherits = ["base"] + platforms = [ + "linux/amd64", + "linux/arm64" + ] +} diff --git a/utils/docker/docker-bake.manylinux.hcl b/utils/docker/docker-bake.manylinux.hcl new file mode 100644 index 00000000..0d122ffc --- /dev/null +++ b/utils/docker/docker-bake.manylinux.hcl @@ -0,0 +1,51 @@ +target "base" { + dockerfile = "Dockerfile.manylinux_2_28-base" + context = "./utils/docker" +} + +target "plugins-base" { + dockerfile = "./docker/Dockerfile.manylinux_2_28-plugins-deps" + context = "./utils" +} + +target "x86_64" { + inherits = ["base"] + platforms = ["linux/amd64"] + tags = ["wasmedge/wasmedge:manylinux_2_28_x86_64"] + args = { + LLVM_TARGETS = "X86;BPF", + LLVM_TRIPLE = "x86_64-pc-linux-gnu" + } +} + +target "x86_64-plugins" { + inherits = ["plugins-base"] + platforms = ["linux/amd64"] + tags = ["wasmedge/wasmedge:manylinux_2_28_x86_64-plugins-deps"] + contexts = { + "wasmedge/wasmedge:manylinux_2_28_x86_64"= "target:x86_64" + } +} + +target "aarch64" { + inherits = ["base"] + platforms = ["linux/arm64"] + tags = ["wasmedge/wasmedge:manylinux_2_28_aarch64"] + args = { + BASE_IMAGE = "quay.io/pypa/manylinux_2_28_aarch64", + LLVM_TARGETS = "AArch64;BPF", + LLVM_TRIPLE = "aarch64-redhat-linux-gnu" + } +} + +target "aarch64-plugins" { + inherits = ["plugins-base"] + platforms = ["linux/arm64"] + tags = ["wasmedge/wasmedge:manylinux_2_28_aarch64-plugins-deps"] + contexts = { + "wasmedge/wasmedge:manylinux_2_28_aarch64" = "target:aarch64" + } + args = { + BASE_IMAGE = "wasmedge/wasmedge:manylinux_2_28_aarch64" + } +} diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl new file mode 100644 index 00000000..b2012d5f --- /dev/null +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -0,0 +1,178 @@ +group "default" { + targets = [ + "cuda", + "final" + ] +} + +group "latest" { + targets = [ + "base-2404-clang", + ] +} + +group "focal" { + targets = [ + "base-2004-clang", + "base-2004-gcc", + "plugins-2004-clang", + "plugins-2004-gcc", + ] +} + +group "jammy" { + targets = [ + "base-2204-clang", + "base-2204-gcc", + "plugins-2204-clang", + "plugins-2204-gcc", + ] +} + +group "noble" { + targets = [ + "base-2404-clang", + "base-2404-gcc", + "plugins-2404-clang", + "plugins-2404-gcc", + ] +} + +function "no-dot" { + params = [ubuntu] + result = replace(ubuntu, ".", "") +} + +function "major" { + params = [ubuntu] + result = regex("^[[:digit:]]+", ubuntu) +} + +function "tags-latest" { + params = [target, ubuntu, toolchain] + result = target == "base" && ubuntu == "24.04" && toolchain == "clang" ? "latest" : "" +} + +function "tags-latest-backports" { + params = [target, ubuntu, toolchain] + result = ubuntu == "24.04" ? join("-", compact([ + "ubuntu", + "build", + toolchain, + target == "plugins" ? "plugins-deps" : "", + ])) : "" +} + +function "tags-backports" { + params = [target, ubuntu, toolchain] + result = join("-", compact([ + "ubuntu", + ubuntu, + "build", + toolchain, + target == "plugins" ? "plugins-deps" : "", + ])) +} + +function "tags-simplified" { + params = [target, ubuntu, toolchain] + result = toolchain == "clang" ? join("-", compact([ + "ubuntu", + ubuntu, + target == "plugins" ? "plugins" : "", + ])) : "" +} + +function "tags" { + params = [target, ubuntu, toolchain] + result = [for tag in compact([ + tags-latest(target, ubuntu, toolchain), + tags-latest-backports(target, ubuntu, toolchain), + tags-backports(target, ubuntu, toolchain), + tags-simplified(target, ubuntu, toolchain), + ]) : "wasmedge/wasmedge:${tag}"] +} + +target "base" { + dockerfile = "Dockerfile.ubuntu-base" + context = "./utils/docker" + + matrix = { + ubuntu = ["20.04", "22.04", "24.04"] + } + + name = "base-${no-dot(ubuntu)}" + tags = ["local/tmp:base-${ubuntu}"] + args = { + UBUNTU_VER = major(ubuntu) + } +} + +target "plugins" { + dockerfile = "./docker/Dockerfile.ubuntu-plugins-deps" + context = "./utils" + + matrix = { + ubuntu = ["20.04", "22.04", "24.04"] + } + + name = "plugins-${no-dot(ubuntu)}" + contexts = { + "local/tmp:base-${ubuntu}" = "target:base-${no-dot(ubuntu)}" + } + tags = ["local/tmp:plugins-${ubuntu}"] + args = { + BASE_IMAGE = "local/tmp:base-${ubuntu}" + UBUNTU_VER = major(ubuntu) + } +} + +target "final" { + matrix = { + parent = ["base", "plugins"] + ubuntu = ["20.04", "22.04", "24.04"] + toolchain = ["clang", "gcc"] + } + + dockerfile = "Dockerfile.ubuntu-env" + context = "./utils/docker" + + name = "${parent}-${no-dot(ubuntu)}-${toolchain}" + contexts = { + "local/tmp:${parent}-${ubuntu}" = "target:${parent}-${no-dot(ubuntu)}" + } + tags = tags(parent, ubuntu, toolchain) + args = { + BASE_IMAGE = "local/tmp:${parent}-${ubuntu}" + TOOLCHAIN = toolchain + } +} + +target "cuda" { + dockerfile = "Dockerfile.ubuntu-cuda" + context = "./utils/docker" + + matrix = { + cuda = ["11.3", "12.0"] + } + + name = "base-2004-gcc-cuda${major(cuda)}" + contexts = { + "wasmedge/wasmedge:ubuntu-20.04-build-gcc" = "target:base-2004-gcc" + } + tags = ["wasmedge/wasmedge:ubuntu-20.04-build-gcc-cuda${major(cuda)}"] + args = { + BASE_IMAGE = "wasmedge/wasmedge:ubuntu-20.04-build-gcc" + NVCC_VER = replace(cuda, ".", "-") + } +} + +# TODO: Refactor with multi-arch image +target "base-2004-clang-aarch64" { + inherits = ["base-2004"] + contexts = { + "local/tmp:base-2004" = "target:base-2004" + } + tags = [for tag in tags("base", "20.04", "clang") : "${tag}-aarch64"] + platforms = ["linux/arm64"] +} diff --git a/utils/ffmpeg/download-ffmpeg-sample-video.sh b/utils/ffmpeg/download-ffmpeg-sample-video.sh new file mode 100644 index 00000000..bae80b6e --- /dev/null +++ b/utils/ffmpeg/download-ffmpeg-sample-video.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +# The video below is sourced from an ffmpeg-libav-tutorial repository. +# Source: https://github.com/leandromoreira/ffmpeg-libav-tutorial/blob/master/LICENSE. +TODIR=$1 +SAMPLE_VIDEO=https://raw.githubusercontent.com/Hrushi20/rust-ffmpeg/master/assets/bunny.mp4 +if [[ $# -eq 0 ]]; then + TODIR=. +fi +if [ ! -d $TODIR ]; then + mkdir $TODIR +fi + +if [ ! -f $TODIR/sample_video.mp4 ]; then + curl -sL $SAMPLE_VIDEO -o $TODIR/sample_video.mp4 + cp $TODIR/sample_video.mp4 $TODIR/dummy.mp4 # Dummy file to manipulate and run tests on file. +fi diff --git a/utils/ffmpeg/install-ffmpeg-v7.1.1.sh b/utils/ffmpeg/install-ffmpeg-v7.1.1.sh new file mode 100644 index 00000000..4e049948 --- /dev/null +++ b/utils/ffmpeg/install-ffmpeg-v7.1.1.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +set -e + +curl -sL https://github.com/FFmpeg/FFmpeg/archive/refs/tags/n7.1.1.zip -o ffmpeg.zip + +unzip ffmpeg.zip + +mkdir -p FFmpeg-n7.1.1/output +cd FFmpeg-n7.1.1 +./configure --prefix=$(pwd)/output --enable-gpl --enable-nonfree --enable-shared --disable-static +make && make install +cd .. + +rm -rf ffmpeg.zip diff --git a/utils/opencvmini/install-opencvmini.sh b/utils/opencvmini/install-opencvmini.sh new file mode 100644 index 00000000..43afef93 --- /dev/null +++ b/utils/opencvmini/install-opencvmini.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC +OPENCV_VERSION=${OPENCV_VERSION:-4.8.0} + +wget -O opencv.zip https://github.com/opencv/opencv/archive/refs/tags/${OPENCV_VERSION}.zip + +unzip opencv.zip +mv opencv-${OPENCV_VERSION} opencv + +mkdir -p opencv/build && cd opencv/build +# Configure +cmake -GNinja .. +# Build +cmake --build . +# Install to system +cmake --install . + +cd - && rm -rf opencv opencv.zip diff --git a/utils/wasi-crypto/build-openssl.sh b/utils/wasi-crypto/build-openssl.sh new file mode 100755 index 00000000..3adef636 --- /dev/null +++ b/utils/wasi-crypto/build-openssl.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +echo "Building OpenSSL for wasi-crypto..." +# Get OpenSSL source +curl -s -L -O --remote-name-all https://www.openssl.org/source/openssl-1.1.1n.tar.gz +echo "40dceb51a4f6a5275bde0e6bf20ef4b91bfc32ed57c0552e2e8e15463372b17a openssl-1.1.1n.tar.gz" | sha256sum -c +tar -xf openssl-1.1.1n.tar.gz +cd ./openssl-1.1.1n +# Configuring OpenSSL requires newer Perl. +curl -s -L -O --remote-name-all https://www.cpan.org/src/5.0/perl-5.34.0.tar.gz +tar -xf perl-5.34.0.tar.gz +cd perl-5.34.0 +mkdir localperl +./Configure -des -Dprefix=$(pwd)/localperl/ +make -j +# too long! +# make test +make install +export PATH="$(pwd)/localperl/bin/:$PATH" +cd .. +# Configure by previous perl +mkdir openssl +./perl-5.34.0/localperl/bin/perl ./config --prefix=$(pwd)/openssl --openssldir=$(pwd)/openssl +make -j +make test +make install +cd .. diff --git a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh new file mode 100755 index 00000000..0256ab83 --- /dev/null +++ b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +if [[ ! -v "${CMAKE_BUILD_TYPE}" ]]; then + CMAKE_BUILD_TYPE=Release +fi + +ldconfig +git config --global --add safe.directory $(pwd) +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DWASMEDGE_BUILD_TESTS=ON -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="OpenVINO" .; then + echo === CMakeOutput.log === + cat build/CMakeFiles/CMakeOutput.log + echo === CMakeError.log === + cat build/CMakeFiles/CMakeError.log + exit 1 +fi +cmake --build build diff --git a/utils/wasi-nn/install-chattts.sh b/utils/wasi-nn/install-chattts.sh new file mode 100644 index 00000000..6361cae6 --- /dev/null +++ b/utils/wasi-nn/install-chattts.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +set -e + +apt-get update +apt-get -y install python3 python3-dev + +# Use latest pip +python3 -m venv chattts_venv +source chattts_venv/bin/activate +pip install --upgrade pip + +# Install PyTorch CPU version to save space +pip --python /usr/bin/python3 install --break-system-packages --index-url https://download.pytorch.org/whl/cpu 'torch<=2.6.0' 'torchaudio<=2.6.0' +pip --python /usr/bin/python3 install --break-system-packages chattts==0.2.4, transformers==4.46.3 + +# Remove wheel cache +pip --python /usr/bin/python3 cache purge + +# Clean up +deactivate +rm -rf chattts_venv diff --git a/utils/wasi-nn/install-libpiper.sh b/utils/wasi-nn/install-libpiper.sh new file mode 100755 index 00000000..4a4ee475 --- /dev/null +++ b/utils/wasi-nn/install-libpiper.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2026 Second State INC + +set -e + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +PATCH_PATH="${SCRIPT_DIR}/libpiper.patch" + +if [ ! -f "${PATCH_PATH}" ]; then + echo "Error: libpiper.patch not found at ${PATCH_PATH}" + exit 1 +fi + +PIPER_REPO="https://github.com/OHF-Voice/piper1-gpl.git" +PIPER_COMMIT="32b95f8c1f0dc0ce27a6acd1143de331f61af777" +PIPER_INSTALL_TO="/usr/local" + +case "$(uname -m)" in + 'x86_64') ;; + 'aarch64') ;; + *) + echo "Unsupported architecture for libpiper: $(uname -m)" >&2 + exit 1 + ;; +esac + +rm -rf piper-source +git clone --depth 1 "${PIPER_REPO}" piper-source +cd piper-source +git fetch --depth 1 origin "${PIPER_COMMIT}" +git checkout FETCH_HEAD + +cp "${PATCH_PATH}" . +patch -p1 < libpiper.patch +cd libpiper + +cmake -Bbuild \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX="${PIPER_INSTALL_TO}" \ + -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON +cmake --build build --parallel $(nproc) +cmake --install build + +echo "Copying espeak-ng static libraries to ${PIPER_INSTALL_TO}/lib..." + +if [ -f "build/espeak_ng-install/lib/libespeak-ng.a" ]; then + cp "build/espeak_ng-install/lib/libespeak-ng.a" "${PIPER_INSTALL_TO}/lib/" +else + find build -name "libespeak-ng.a" -exec cp {} "${PIPER_INSTALL_TO}/lib/" \; -quit +fi + +find build -name "libucd.a" -exec cp {} "${PIPER_INSTALL_TO}/lib/" \; -quit + +if [ ! -f "${PIPER_INSTALL_TO}/lib/libespeak-ng.a" ]; then + echo "Error: Failed to install libespeak-ng.a" + exit 1 +fi + +if [ ! -f "${PIPER_INSTALL_TO}/lib/libucd.a" ]; then + echo "Error: Failed to install libucd.a" + exit 1 +fi + +echo "Installing espeak-ng-data to ${PIPER_INSTALL_TO}/share..." + +if [ -d "build/espeak_ng-install/share/espeak-ng-data" ]; then + mkdir -p "${PIPER_INSTALL_TO}/share" + cp -r "build/espeak_ng-install/share/espeak-ng-data" "${PIPER_INSTALL_TO}/share/" + echo "Espeak-ng-data installed successfully." +else + echo "Error: espeak-ng-data directory not found in build tree!" + exit 1 +fi + +cd ../.. +rm -rf piper-source + +ldconfig diff --git a/utils/wasi-nn/install-onnxruntime.sh b/utils/wasi-nn/install-onnxruntime.sh new file mode 100644 index 00000000..dad1fb56 --- /dev/null +++ b/utils/wasi-nn/install-onnxruntime.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +set -e + +case "$(uname -m)" in + 'x86_64') ARCH='x64' ;; + 'aarch64') ARCH='aarch64' ;; + *) + echo 'Cannot determine architecture for onnxruntime' >&2 + exit 1 + ;; +esac + +: ${ONNXRUNTIME_VERSION:=1.14.1} + +ONNXRUNTIME_NAME="onnxruntime-linux-${ARCH}-${ONNXRUNTIME_VERSION}" +ONNXRUNTIME_TGZ="${ONNXRUNTIME_NAME}.tgz" + +curl -LO "https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/${ONNXRUNTIME_TGZ}" +tar zxf "${ONNXRUNTIME_TGZ}" +mv "${ONNXRUNTIME_NAME}/include/"* /usr/local/include/ +mv "${ONNXRUNTIME_NAME}/lib/"* /usr/local/lib/ +rm -rf "${ONNXRUNTIME_TGZ}" "${ONNXRUNTIME_NAME}" + +ldconfig diff --git a/utils/wasi-nn/install-openvino-genai.sh b/utils/wasi-nn/install-openvino-genai.sh new file mode 100644 index 00000000..2be3691e --- /dev/null +++ b/utils/wasi-nn/install-openvino-genai.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC +set -e + +if [[ ! -v "${OPENVINOGEN_VERSION}" ]]; then + OPENVINOGEN_VERSION="2025.0.0.0" +fi +if [[ ! -v "${OPENVINOGEN_YEAR}" ]]; then + OPENVINOGEN_YEAR="2025.0" +fi +if [[ ! -v "${OPENVINOGEN_DIRNAME}" ]]; then + OPENVINOGEN_DIRNAME="openvino_genai" +fi + +if [[ ! -v "${OPENVINO_UBUNTU_VERSION}" ]]; then + source /etc/os-release + OPENVINO_UBUNTU_VERSION="${VERSION_ID::2}" +fi + +UBUNTU_VERSION="ubuntu${OPENVINO_UBUNTU_VERSION:-20}" +OPENVINOGEN_TGZ_NAME="openvino_genai_${UBUNTU_VERSION}_${OPENVINOGEN_VERSION}_x86_64" + + +echo "Installing OpenVINO GenAI with version ${OPENVINOGEN_VERSION}" +curl -L https://storage.openvinotoolkit.org/repositories/openvino_genai/packages/${OPENVINOGEN_YEAR}/linux/${OPENVINOGEN_TGZ_NAME}.tar.gz --output ${OPENVINOGEN_TGZ_NAME}.tgz +tar -xf ${OPENVINOGEN_TGZ_NAME}.tgz +mv ${OPENVINOGEN_TGZ_NAME} $OPENVINOGEN_DIRNAME +./${OPENVINOGEN_DIRNAME}/install_dependencies/install_openvino_dependencies.sh -y + +rm ${OPENVINOGEN_TGZ_NAME}.tgz + +echo "OpenVINO GenAI installed at `pwd`/$OPENVINOGEN_DIRNAME" +echo "Please source the setupvars.sh script in the OpenVINO GenAI directory to use the OpenVINO GenAI tools." +echo "# source $PWD/$OPENVINOGEN_DIRNAME/setupvars.sh" diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh new file mode 100755 index 00000000..432f1ac8 --- /dev/null +++ b/utils/wasi-nn/install-openvino.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC +set -e + +if [[ ! -v "${OPENVINO_VERSION}" ]]; then + OPENVINO_VERSION="2025.0.0" +fi +if [[ ! -v "${OPENVINO_YEAR}" ]]; then + OPENVINO_YEAR="2025" +fi + +echo "Installing OpenVINO with version ${OPENVINO_VERSION}" +KEY_FILE=GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB +wget https://apt.repos.intel.com/intel-gpg-keys/$KEY_FILE && \ + apt-key add $KEY_FILE && \ + rm -f $KEY_FILE +UBUNTU_VERSION="ubuntu${OPENVINO_UBUNTU_VERSION:-20}" +echo "deb https://apt.repos.intel.com/openvino/$OPENVINO_YEAR ${UBUNTU_VERSION} main" | tee /etc/apt/sources.list.d/intel-openvino-$OPENVINO_YEAR.list +apt update +apt-get -y install openvino-$OPENVINO_VERSION +ldconfig diff --git a/utils/wasi-nn/install-pytorch.sh b/utils/wasi-nn/install-pytorch.sh new file mode 100755 index 00000000..55f77ac3 --- /dev/null +++ b/utils/wasi-nn/install-pytorch.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +if [[ ! -n ${PYTORCH_VERSION} ]]; then + PYTORCH_VERSION="2.5.1" +fi + +if [[ ! -n ${PYTORCH_INSTALL_TO} ]]; then + PYTORCH_INSTALL_TO=. +fi + +PYTORCH_LINK="libtorch-cxx11-abi" +PYTORCH_SHA="618ca54eef82a1dca46ff1993d5807d9c0deb0bae147da4974166a147cb562fa" + +for i in "$@"; do + case $i in + --disable-cxx11-abi) + PYTORCH_LINK="libtorch" + PYTORCH_SHA="21d05ad61935fc70912c779443dba112bda9c9ec1c999345d724935828f81c55" + shift + ;; + esac +done + +if [ ! -d ${PYTORCH_INSTALL_TO}/libtorch ]; then + curl -s -L -O --remote-name-all https://download.pytorch.org/libtorch/cpu/${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip + echo "${PYTORCH_SHA} ${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" | sha256sum -c + unzip -q "${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" -d ${PYTORCH_INSTALL_TO} + rm -f "${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" +fi diff --git a/utils/wasi-nn/libpiper.patch b/utils/wasi-nn/libpiper.patch new file mode 100644 index 00000000..e74cb4bd --- /dev/null +++ b/utils/wasi-nn/libpiper.patch @@ -0,0 +1,22 @@ +diff --git a/libpiper/CMakeLists.txt b/libpiper/CMakeLists.txt +index 57349c1..2e7ddd8 100644 +--- a/libpiper/CMakeLists.txt ++++ b/libpiper/CMakeLists.txt +@@ -2,6 +2,8 @@ + cmake_minimum_required(VERSION 3.26) + project(piper LANGUAGES C CXX) + ++option(BUILD_SHARED_LIBS "Build using shared libraries" OFF) ++ + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + +@@ -112,7 +114,7 @@ endif() + + # ---- libpiper --- + +-add_library(piper SHARED ++add_library(piper + ${CMAKE_CURRENT_SOURCE_DIR}/src/piper.cpp + ) + diff --git a/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh new file mode 100755 index 00000000..39ff9d05 --- /dev/null +++ b/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +ldconfig +export LD_LIBRARY_PATH="$(pwd)/build/lib/api:$LD_LIBRARY_PATH" + +cd build +ctest +cd - diff --git a/utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch b/utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch new file mode 100644 index 00000000..937bbf0f --- /dev/null +++ b/utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch @@ -0,0 +1,42 @@ +From a33de43ac84703faa5a0ee4d7403360e0efb1922 Mon Sep 17 00:00:00 2001 +From: Shen-Ta Hsieh +Date: Sat, 19 Jun 2021 10:03:12 +0800 +Subject: [PATCH] [PATCH] Disable other tests except wasmedge + +--- + compat.py | 12 +++++++----- + 1 file changed, 7 insertions(+), 5 deletions(-) + +diff --git a/compat.py b/compat.py +index 9341520..da54714 100755 +--- a/compat.py ++++ b/compat.py +@@ -28,8 +28,14 @@ def load_config(filepath): + return config + + def test(cmd, config, cwd): ++ print(' '.join(cmd)) + result = subprocess.run(cmd, cwd=cwd, encoding='utf8', input=config.get('stdin'), timeout=config.get('timeout', 5), capture_output=True) +- assert_result(result, config) ++ try: ++ assert_result(result, config) ++ except: ++ print('==stderr==') ++ print(result.stderr) ++ raise + + def test_deno(filepath, config, cwd): + cmd = ['deno', 'run'] +@@ -202,10 +208,6 @@ def main(): + inputs.extend(sorted(glob.glob("target/wasm32-wasi/**/*.wasm"))) + + tests = { +- "deno": test_deno, +- "node": test_node, +- "wasmer": test_wasmer, +- "wasmtime": test_wasmtime, + "wasmedge": test_wasmedge, + } + +-- +2.31.1 diff --git a/utils/wasi-test/run-wasi-test.sh b/utils/wasi-test/run-wasi-test.sh new file mode 100755 index 00000000..478255e4 --- /dev/null +++ b/utils/wasi-test/run-wasi-test.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +# Test WasmEdge WASI layer. +# The testcase is from https://github.com/khronosproject/wasi-test + +set -Eeuo pipefail +trap cleanup SIGINT SIGTERM ERR EXIT + +script_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P) +current_dir=$(pwd -P) + +usage() { + cat <&2 -e "${1-}" +} + +die() { + local msg=$1 + local code=${2-1} # default exit status 1 + msg "$msg" + exit "$code" +} + +parse_params() { + while :; do + case "${1-}" in + -h | --help) usage ;; + -v | --verbose) set -x ;; + -?*) die "Unknown option: $1" ;; + *) break ;; + esac + shift + done + + if ! command -v realpath &> /dev/null; then + realpath() { + readlink -f -- "$@" + } + fi + + local wasmedge_path=$(realpath "${1-}") + msg "path = $wasmedge_path" + if [[ x"$wasmedge_path" != x ]]; then + export PATH="$wasmedge_path:$PATH" + fi + return 0 +} + +check_command() { + if ! command -v "$1" &> /dev/null; then + die "$1 not found!" + exit 1 + fi + return 0 +} + +parse_params "$@" +check_command git +check_command python3 +check_command wasmedgec +check_command wasmedge + +msg "Cloning git repo..." +git clone https://github.com/khronosproject/wasi-test.git --depth 1 +cd wasi-test + +msg "Applying patch..." +git apply "$script_dir"/0001-PATCH-Disable-other-tests-except-wasmedge.patch + +if command -v cargo &> /dev/null; then + msg "Building wasm files..." + cargo build --release --target wasm32-wasi +else + curl -L -O https://github.com/khronosproject/wasi-test-suite/archive/refs/heads/master.tar.gz + mkdir -p target/wasm32-wasi + tar -xf master.tar.gz -C target/wasm32-wasi +fi + +msg "Running tests..." +python3 compat.py