diff --git a/MAINTAINERS.yml b/MAINTAINERS.yml index c4d65cc92d13..0edd507c0df8 100644 --- a/MAINTAINERS.yml +++ b/MAINTAINERS.yml @@ -2701,6 +2701,17 @@ West: labels: - manifest-tflite-micro +"West project: thrift": + status: maintained + maintainers: + - cfriedt + files: + - modules/thrift/Kconfig + - samples/modules/thrift/ + - tests/lib/thrift/ + labels: + - manifest-thrift + "West project: tinycrypt": status: odd fixes files: diff --git a/modules/Kconfig b/modules/Kconfig index 854cbb420f0a..a9b4abc87c6c 100644 --- a/modules/Kconfig +++ b/modules/Kconfig @@ -36,6 +36,7 @@ source "modules/Kconfig.st" source "modules/Kconfig.stm32" source "modules/Kconfig.syst" source "modules/Kconfig.telink" +source "modules/thrift/Kconfig" source "modules/Kconfig.tinycrypt" source "modules/Kconfig.vega" source "modules/Kconfig.wurthelektronik" @@ -95,6 +96,9 @@ comment "zcbor module not available." comment "CHRE module not available." depends on !ZEPHYR_CHRE_MODULE +comment "THRIFT module not available." + depends on !ZEPHYR_THRIFT_MODULE + # This ensures that symbols are available in Kconfig for dependency checking # and referencing, while keeping the settings themselves unavailable when the # modules are not present in the workspace diff --git a/modules/thrift/CMakeLists.txt b/modules/thrift/CMakeLists.txt new file mode 100644 index 000000000000..468286c03e6c --- /dev/null +++ b/modules/thrift/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright 2022 Meta +# SPDX-License-Identifier: Apache-2.0 + +if(CONFIG_THRIFT) + +set(THRIFT_UPSTREAM ${ZEPHYR_THRIFT_MODULE_DIR}) + +zephyr_library() + +zephyr_include_directories(src) +zephyr_include_directories(include) +zephyr_include_directories(${THRIFT_UPSTREAM}/lib/cpp/src) + +zephyr_library_sources( + src/_stat.c + src/thrift/server/TFDServer.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/protocol/TProtocol.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/server/TConnectedClient.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/server/TSimpleServer.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/SocketCommon.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/TBufferTransports.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/TFDTransport.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/TTransportException.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/TServerSocket.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/transport/TSocket.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/TApplicationException.cpp + ${THRIFT_UPSTREAM}/lib/cpp/src/thrift/TOutput.cpp + + # Replace with upstream equivalents when Zephyr's std::thread, etc, are fixed + src/thrift/concurrency/Mutex.cpp + src/thrift/server/TServerFramework.cpp +) + +zephyr_library_sources_ifdef(CONFIG_THRIFT_SSL_SOCKET + # Replace with upstream equivalents when Zephyr's std::thread, etc, are fixed + src/thrift/transport/TSSLSocket.cpp + src/thrift/transport/TSSLServerSocket.cpp +) + +# needed because std::iterator was deprecated with -std=c++17 +zephyr_library_compile_options(-Wno-deprecated-declarations) + +endif(CONFIG_THRIFT) diff --git a/modules/thrift/Kconfig b/modules/thrift/Kconfig new file mode 100644 index 000000000000..df9f8abf7a4b --- /dev/null +++ b/modules/thrift/Kconfig @@ -0,0 +1,31 @@ +# Copyright 2022 Meta +# SPDX-License-Identifier: Apache-2.0 + +config ZEPHYR_THRIFT_MODULE + bool + +menuconfig THRIFT + bool "Support for Thrift [EXPERIMENTAL]" + select EXPERIMENTAL + depends on CPP + depends on STD_CPP17 + depends on CPP_EXCEPTIONS + depends on POSIX_API + help + Enable this option to support Apache Thrift + +if THRIFT + +config THRIFT_SSL_SOCKET + bool "TSSLSocket support for Thrift" + depends on MBEDTLS + depends on MBEDTLS_PEM_CERTIFICATE_FORMAT + depends on NET_SOCKETS_SOCKOPT_TLS + help + Enable this option to support TSSLSocket for Thrift + +module = THRIFT +module-str = THRIFT +source "subsys/logging/Kconfig.template.log_config" + +endif # THRIFT diff --git a/modules/thrift/cmake/thrift.cmake b/modules/thrift/cmake/thrift.cmake new file mode 100644 index 000000000000..93f01d1a07af --- /dev/null +++ b/modules/thrift/cmake/thrift.cmake @@ -0,0 +1,33 @@ +# Copyright 2022 Meta +# SPDX-License-Identifier: Apache-2.0 + +find_program(THRIFT_EXECUTABLE thrift) +if(NOT THRIFT_EXECUTABLE) + message(FATAL_ERROR "The 'thrift' command was not found") +endif() + +function(thrift + target # CMake target (for dependencies / headers) + lang # The language for generated sources + lang_opts # Language options (e.g. ':no_skeleton') + out_dir # Output directory for generated files + # (do not include 'gen-cpp', etc) + source_file # The .thrift source file + options # Additional thrift options + + # Generated files in ${ARGN} + ) + file(MAKE_DIRECTORY ${out_dir}) + add_custom_command( + OUTPUT ${ARGN} + COMMAND + ${THRIFT_EXECUTABLE} + --gen ${lang}${lang_opts} + -o ${out_dir} + ${source_file} + ${options} + DEPENDS ${source_file} + ) + + target_include_directories(${target} PRIVATE ${out_dir}/gen-${lang}) +endfunction() diff --git a/modules/thrift/src/_stat.c b/modules/thrift/src/_stat.c new file mode 100644 index 000000000000..688e1a53ceff --- /dev/null +++ b/modules/thrift/src/_stat.c @@ -0,0 +1,20 @@ +/* + * Copyright 2022 Meta + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include + +int stat(const char *restrict path, struct stat *restrict buf) +{ + ARG_UNUSED(path); + ARG_UNUSED(buf); + + errno = ENOTSUP; + + return -1; +} diff --git a/modules/thrift/src/thrift/concurrency/Mutex.cpp b/modules/thrift/src/thrift/concurrency/Mutex.cpp new file mode 100644 index 000000000000..422914349136 --- /dev/null +++ b/modules/thrift/src/thrift/concurrency/Mutex.cpp @@ -0,0 +1,44 @@ +/* + * Copyright 2022 Meta + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace apache +{ +namespace thrift +{ +namespace concurrency +{ + +Mutex::Mutex() +{ +} + +void Mutex::lock() const +{ +} + +bool Mutex::trylock() const +{ + return false; +} + +bool Mutex::timedlock(int64_t milliseconds) const +{ + return false; +} + +void Mutex::unlock() const +{ +} + +void *Mutex::getUnderlyingImpl() const +{ + return nullptr; +} +} // namespace concurrency +} // namespace thrift +} // namespace apache diff --git a/modules/thrift/src/thrift/config.h b/modules/thrift/src/thrift/config.h new file mode 100644 index 000000000000..d11f868ec0cb --- /dev/null +++ b/modules/thrift/src/thrift/config.h @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2023 Meta + * + * SPDX-License-Identifier: Apache-2.0 + */ +/* config.h. Generated from config.hin by configure. */ +/* config.hin. Generated from configure.ac by autoheader. */ + +#ifndef ZEPHYR_MODULES_THRIFT_SRC_THRIFT_CONFIG_H_ +#define ZEPHYR_MODULES_THRIFT_SRC_THRIFT_CONFIG_H_ + +/* Possible value for SIGNED_RIGHT_SHIFT_IS */ +#define ARITHMETIC_RIGHT_SHIFT 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_ARPA_INET_H 1 + +/* Define to 1 if you have the `clock_gettime' function. */ +#define HAVE_CLOCK_GETTIME 1 + +/* define if the compiler supports basic C++11 syntax */ +#define HAVE_CXX11 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_FCNTL_H 1 + +/* Define to 1 if you have the `gethostbyname' function. */ +#define HAVE_GETHOSTBYNAME 1 + +/* Define to 1 if you have the `gettimeofday' function. */ +#define HAVE_GETTIMEOFDAY 1 + +/* Define to 1 if you have the `inet_ntoa' function. */ +#define HAVE_INET_NTOA 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_INTTYPES_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_LIMITS_H 1 + +/* Define to 1 if your system has a GNU libc compatible `malloc' function, and to 0 otherwise. */ +#define HAVE_MALLOC 1 + +/* Define to 1 if you have the `memmove' function. */ +#define HAVE_MEMMOVE 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_MEMORY_H 1 + +/* Define to 1 if you have the `memset' function. */ +#define HAVE_MEMSET 1 + +/* Define to 1 if you have the `mkdir' function. */ +#define HAVE_MKDIR 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_NETDB_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_NETINET_IN_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_POLL_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_PTHREAD_H 1 + +/* Define to 1 if the system has the type `ptrdiff_t'. */ +#define HAVE_PTRDIFF_T 1 + +/* Define to 1 if your system has a GNU libc compatible `realloc' function, and to 0 otherwise. */ +#define HAVE_REALLOC 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SCHED_H 1 + +/* Define to 1 if you have the `select' function. */ +#define HAVE_SELECT 1 + +/* Define to 1 if you have the `socket' function. */ +#define HAVE_SOCKET 1 + +/* Define to 1 if stdbool.h conforms to C99. */ +#define HAVE_STDBOOL_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_STDDEF_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_STDINT_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_STDLIB_H 1 + +/* Define to 1 if you have the `strchr' function. */ +#define HAVE_STRCHR 1 + +/* Define to 1 if you have the `strdup' function. */ +#define HAVE_STRDUP 1 + +/* Define to 1 if you have the `strerror' function. */ +#define HAVE_STRERROR 1 + +/* Define to 1 if you have the `strerror_r' function. */ +#define HAVE_STRERROR_R 1 + +/* Define to 1 if you have the `strftime' function. */ +#define HAVE_STRFTIME 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_STRINGS_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_STRING_H 1 + +/* Define to 1 if you have the `strstr' function. */ +#define HAVE_STRSTR 1 + +/* Define to 1 if you have the `strtol' function. */ +#define HAVE_STRTOL 1 + +/* Define to 1 if you have the `strtoul' function. */ +#define HAVE_STRTOUL 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_IOCTL_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_RESOURCE_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_SELECT_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_SOCKET_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_STAT_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_TIME_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_SYS_TYPES_H 1 + +/* Define to 1 if you have the header file. */ +#define HAVE_UNISTD_H 1 + +/* Define to 1 if you have the `vprintf' function. */ +#define HAVE_VPRINTF 1 + +/* define if zlib is available */ +/* #undef HAVE_ZLIB */ + +/* Possible value for SIGNED_RIGHT_SHIFT_IS */ +#define LOGICAL_RIGHT_SHIFT 2 + +/* Define as the return type of signal handlers (`int' or `void'). */ +#define RETSIGTYPE void + +/* Define to the type of arg 1 for `select'. */ +#define SELECT_TYPE_ARG1 int + +/* Define to the type of args 2, 3 and 4 for `select'. */ +#define SELECT_TYPE_ARG234 (fd_set *) + +/* Define to the type of arg 5 for `select'. */ +#define SELECT_TYPE_ARG5 (struct timeval *) + +/* Indicates the effect of the right shift operator on negative signed integers */ +#define SIGNED_RIGHT_SHIFT_IS 1 + +/* Define to 1 if you have the ANSI C header files. */ +#define STDC_HEADERS 1 + +/* Define to 1 if you can safely include both and . */ +#define TIME_WITH_SYS_TIME 1 + +/* Possible value for SIGNED_RIGHT_SHIFT_IS */ +#define UNKNOWN_RIGHT_SHIFT 3 + +#endif /* ZEPHYR_MODULES_THRIFT_SRC_THRIFT_CONFIG_H_ */ diff --git a/modules/thrift/src/thrift/protocol/TBinaryProtocol.h b/modules/thrift/src/thrift/protocol/TBinaryProtocol.h new file mode 100644 index 000000000000..6c8b4cf992d6 --- /dev/null +++ b/modules/thrift/src/thrift/protocol/TBinaryProtocol.h @@ -0,0 +1,286 @@ +/* + * Copyright (c) 2006- Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ 1 + +#include +#include + +#include + +namespace apache +{ +namespace thrift +{ +namespace protocol +{ + +/** + * The default binary protocol for thrift. Writes all data in a very basic + * binary format, essentially just spitting out the raw bytes. + * + */ +template +class TBinaryProtocolT : public TVirtualProtocol> +{ +public: + static const int32_t VERSION_MASK = ((int32_t)0xffff0000); + static const int32_t VERSION_1 = ((int32_t)0x80010000); + // VERSION_2 (0x80020000) was taken by TDenseProtocol (which has since been removed) + + TBinaryProtocolT(std::shared_ptr trans) + : TVirtualProtocol>(trans), + trans_(trans.get()), string_limit_(0), container_limit_(0), strict_read_(false), + strict_write_(true) + { + } + + TBinaryProtocolT(std::shared_ptr trans, int32_t string_limit, + int32_t container_limit, bool strict_read, bool strict_write) + : TVirtualProtocol>(trans), + trans_(trans.get()), string_limit_(string_limit), + container_limit_(container_limit), strict_read_(strict_read), + strict_write_(strict_write) + { + } + + void setStringSizeLimit(int32_t string_limit) + { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) + { + container_limit_ = container_limit; + } + + void setStrict(bool strict_read, bool strict_write) + { + strict_read_ = strict_read; + strict_write_ = strict_write; + } + + /** + * Writing functions. + */ + + /*ol*/ uint32_t writeMessageBegin(const std::string &name, const TMessageType messageType, + const int32_t seqid); + + /*ol*/ uint32_t writeMessageEnd(); + + inline uint32_t writeStructBegin(const char *name); + + inline uint32_t writeStructEnd(); + + inline uint32_t writeFieldBegin(const char *name, const TType fieldType, + const int16_t fieldId); + + inline uint32_t writeFieldEnd(); + + inline uint32_t writeFieldStop(); + + inline uint32_t writeMapBegin(const TType keyType, const TType valType, + const uint32_t size); + + inline uint32_t writeMapEnd(); + + inline uint32_t writeListBegin(const TType elemType, const uint32_t size); + + inline uint32_t writeListEnd(); + + inline uint32_t writeSetBegin(const TType elemType, const uint32_t size); + + inline uint32_t writeSetEnd(); + + inline uint32_t writeBool(const bool value); + + inline uint32_t writeByte(const int8_t byte); + + inline uint32_t writeI16(const int16_t i16); + + inline uint32_t writeI32(const int32_t i32); + + inline uint32_t writeI64(const int64_t i64); + + inline uint32_t writeDouble(const double dub); + + template inline uint32_t writeString(const StrType &str); + + inline uint32_t writeBinary(const std::string &str); + + /** + * Reading functions + */ + + /*ol*/ uint32_t readMessageBegin(std::string &name, TMessageType &messageType, + int32_t &seqid); + + /*ol*/ uint32_t readMessageEnd(); + + inline uint32_t readStructBegin(std::string &name); + + inline uint32_t readStructEnd(); + + inline uint32_t readFieldBegin(std::string &name, TType &fieldType, int16_t &fieldId); + + inline uint32_t readFieldEnd(); + + inline uint32_t readMapBegin(TType &keyType, TType &valType, uint32_t &size); + + inline uint32_t readMapEnd(); + + inline uint32_t readListBegin(TType &elemType, uint32_t &size); + + inline uint32_t readListEnd(); + + inline uint32_t readSetBegin(TType &elemType, uint32_t &size); + + inline uint32_t readSetEnd(); + + inline uint32_t readBool(bool &value); + // Provide the default readBool() implementation for std::vector + using TVirtualProtocol>::readBool; + + inline uint32_t readByte(int8_t &byte); + + inline uint32_t readI16(int16_t &i16); + + inline uint32_t readI32(int32_t &i32); + + inline uint32_t readI64(int64_t &i64); + + inline uint32_t readDouble(double &dub); + + template inline uint32_t readString(StrType &str); + + inline uint32_t readBinary(std::string &str); + + int getMinSerializedSize(TType type) override; + + void checkReadBytesAvailable(TSet &set) override + { + trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_)); + } + + void checkReadBytesAvailable(TList &list) override + { + trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_)); + } + + void checkReadBytesAvailable(TMap &map) override + { + int elmSize = + getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_); + trans_->checkReadBytesAvailable(map.size_ * elmSize); + } + +protected: + template uint32_t readStringBody(StrType &str, int32_t sz); + + Transport_ *trans_; + + int32_t string_limit_; + int32_t container_limit_; + + // Enforce presence of version identifier + bool strict_read_; + bool strict_write_; +}; + +typedef TBinaryProtocolT TBinaryProtocol; +typedef TBinaryProtocolT TLEBinaryProtocol; + +/** + * Constructs binary protocol handlers + */ +template +class TBinaryProtocolFactoryT : public TProtocolFactory +{ +public: + TBinaryProtocolFactoryT() + : string_limit_(0), container_limit_(0), strict_read_(false), strict_write_(true) + { + } + + TBinaryProtocolFactoryT(int32_t string_limit, int32_t container_limit, bool strict_read, + bool strict_write) + : string_limit_(string_limit), container_limit_(container_limit), + strict_read_(strict_read), strict_write_(strict_write) + { + } + + ~TBinaryProtocolFactoryT() override = default; + + void setStringSizeLimit(int32_t string_limit) + { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) + { + container_limit_ = container_limit; + } + + void setStrict(bool strict_read, bool strict_write) + { + strict_read_ = strict_read; + strict_write_ = strict_write; + } + + std::shared_ptr getProtocol(std::shared_ptr trans) override + { + std::shared_ptr specific_trans = + std::dynamic_pointer_cast(trans); + TProtocol *prot; + if (specific_trans) { + prot = new TBinaryProtocolT( + specific_trans, string_limit_, container_limit_, strict_read_, + strict_write_); + } else { + prot = new TBinaryProtocolT( + trans, string_limit_, container_limit_, strict_read_, + strict_write_); + } + + return std::shared_ptr(prot); + } + +private: + int32_t string_limit_; + int32_t container_limit_; + bool strict_read_; + bool strict_write_; +}; + +typedef TBinaryProtocolFactoryT TBinaryProtocolFactory; +typedef TBinaryProtocolFactoryT TLEBinaryProtocolFactory; +} // namespace protocol +} // namespace thrift +} // namespace apache + +#include + +#endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ diff --git a/modules/thrift/src/thrift/server/TConnectedClient.h b/modules/thrift/src/thrift/server/TConnectedClient.h new file mode 100644 index 000000000000..4e4daccf3e05 --- /dev/null +++ b/modules/thrift/src/thrift/server/TConnectedClient.h @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2006- Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TCONNECTEDCLIENT_H_ +#define _THRIFT_SERVER_TCONNECTEDCLIENT_H_ 1 + +#include +#include +#include +#include +#include + +namespace apache +{ +namespace thrift +{ +namespace server +{ + +/** + * This represents a client connected to a TServer. The + * processing loop for a client must provide some required + * functionality common to all implementations so it is + * encapsulated here. + */ + +class TConnectedClient +{ +public: + /** + * Constructor. + * + * @param[in] processor the TProcessor + * @param[in] inputProtocol the input TProtocol + * @param[in] outputProtocol the output TProtocol + * @param[in] eventHandler the server event handler + * @param[in] client the TTransport representing the client + */ + TConnectedClient( + const std::shared_ptr &processor, + const std::shared_ptr &inputProtocol, + const std::shared_ptr &outputProtocol, + const std::shared_ptr &eventHandler, + const std::shared_ptr &client); + + /** + * Destructor. + */ + ~TConnectedClient(); + + /** + * Drive the client until it is done. + * The client processing loop is: + * + * [optional] call eventHandler->createContext once + * [optional] call eventHandler->processContext per request + * call processor->process per request + * handle expected transport exceptions: + * END_OF_FILE means the client is gone + * INTERRUPTED means the client was interrupted + * by TServerTransport::interruptChildren() + * handle unexpected transport exceptions by logging + * handle standard exceptions by logging + * handle unexpected exceptions by logging + * cleanup() + */ + void run(); + +protected: + /** + * Cleanup after a client. This happens if the client disconnects, + * or if the server is stopped, or if an exception occurs. + * + * The cleanup processing is: + * [optional] call eventHandler->deleteContext once + * close the inputProtocol's TTransport + * close the outputProtocol's TTransport + * close the client + */ + virtual void cleanup(); + +private: + std::shared_ptr processor_; + std::shared_ptr inputProtocol_; + std::shared_ptr outputProtocol_; + std::shared_ptr eventHandler_; + std::shared_ptr client_; + + /** + * Context acquired from the eventHandler_ if one exists. + */ + void *opaqueContext_; +}; +} // namespace server +} // namespace thrift +} // namespace apache + +#endif // #ifndef _THRIFT_SERVER_TCONNECTEDCLIENT_H_ diff --git a/modules/thrift/src/thrift/server/TFDServer.cpp b/modules/thrift/src/thrift/server/TFDServer.cpp new file mode 100644 index 000000000000..30cb1e3ad8b1 --- /dev/null +++ b/modules/thrift/src/thrift/server/TFDServer.cpp @@ -0,0 +1,218 @@ +/* + * Copyright 2022 Meta + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include +#include +#include +#include + +#include + +#include +#include + +#include "thrift/server/TFDServer.h" + +LOG_MODULE_REGISTER(TFDServer, LOG_LEVEL_INF); + +using namespace std; + +namespace apache +{ +namespace thrift +{ +namespace transport +{ + +class xport : public TVirtualTransport +{ + public: + xport(int fd) : xport(fd, eventfd(0, EFD_SEMAPHORE)) + { + } + xport(int fd, int efd) : fd(fd), efd(efd) + { + __ASSERT(fd >= 0, "invalid fd %d", fd); + __ASSERT(efd >= 0, "invalid efd %d", efd); + + LOG_DBG("created xport with fd %d and efd %d", fd, efd); + } + + ~xport() + { + close(); + } + + virtual uint32_t read_virt(uint8_t *buf, uint32_t len) override + { + int r; + array pollfds = { + (pollfd){ + .fd = fd, + .events = POLLIN, + .revents = 0, + }, + (pollfd){ + .fd = efd, + .events = POLLIN, + .revents = 0, + }, + }; + + if (!isOpen()) { + return 0; + } + + r = poll(&pollfds.front(), pollfds.size(), -1); + if (r == -1) { + LOG_ERR("failed to poll fds %d, %d: %d", fd, efd, errno); + throw system_error(errno, system_category(), "poll"); + } + + for (auto &pfd : pollfds) { + if (pfd.revents & POLLNVAL) { + LOG_DBG("fd %d is invalid", pfd.fd); + return 0; + } + } + + if (pollfds[0].revents & POLLIN) { + r = ::read(fd, buf, len); + if (r == -1) { + LOG_ERR("failed to read %d bytes from fd %d: %d", len, fd, errno); + system_error(errno, system_category(), "read"); + } + + __ASSERT_NO_MSG(r > 0); + + return uint32_t(r); + } + + __ASSERT_NO_MSG(pollfds[1].revents & POLLIN); + + return 0; + } + + virtual void write_virt(const uint8_t *buf, uint32_t len) override + { + + if (!isOpen()) { + throw TTransportException(TTransportException::END_OF_FILE); + } + + for (int r = 0; len > 0; buf += r, len -= r) { + r = ::write(fd, buf, len); + if (r == -1) { + LOG_ERR("writing %u bytes to fd %d failed: %d", len, fd, errno); + throw system_error(errno, system_category(), "write"); + } + + __ASSERT_NO_MSG(r > 0); + } + } + + void interrupt() + { + if (!isOpen()) { + return; + } + + constexpr uint64_t x = 0xb7e; + int r = ::write(efd, &x, sizeof(x)); + if (r == -1) { + LOG_ERR("writing %zu bytes to fd %d failed: %d", sizeof(x), efd, errno); + throw system_error(errno, system_category(), "write"); + } + + __ASSERT_NO_MSG(r > 0); + + LOG_DBG("interrupted xport with fd %d and efd %d", fd, efd); + + // there is no interrupt() method in the parent class, but the intent of + // interrupt() is to prevent future communication on this transport. The + // most reliable way we have of doing this is to close it :-) + close(); + } + + void close() override + { + if (isOpen()) { + ::close(efd); + LOG_DBG("closed xport with fd %d and efd %d", fd, efd); + + efd = -1; + // we only have a copy of fd and do not own it + fd = -1; + } + } + + bool isOpen() const override + { + return fd >= 0 && efd >= 0; + } + + protected: + int fd; + int efd; +}; + +TFDServer::TFDServer(int fd) : fd(fd) +{ +} + +TFDServer::~TFDServer() +{ + interruptChildren(); + interrupt(); +} + +bool TFDServer::isOpen() const +{ + return fd >= 0; +} + +shared_ptr TFDServer::acceptImpl() +{ + if (!isOpen()) { + throw TTransportException(TTransportException::INTERRUPTED); + } + + children.push_back(shared_ptr(new xport(fd))); + + return children.back(); +} + +THRIFT_SOCKET TFDServer::getSocketFD() +{ + return fd; +} + +void TFDServer::close() +{ + // we only have a copy of fd and do not own it + fd = -1; +} + +void TFDServer::interrupt() +{ + close(); +} + +void TFDServer::interruptChildren() +{ + for (auto c : children) { + auto child = reinterpret_cast(c.get()); + child->interrupt(); + } + + children.clear(); +} +} // namespace transport +} // namespace thrift +} // namespace apache diff --git a/modules/thrift/src/thrift/server/TFDServer.h b/modules/thrift/src/thrift/server/TFDServer.h new file mode 100644 index 000000000000..8799cfa958ca --- /dev/null +++ b/modules/thrift/src/thrift/server/TFDServer.h @@ -0,0 +1,52 @@ +/* + * Copyright 2022 Meta + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef _THRIFT_SERVER_TFDSERVER_H_ +#define _THRIFT_SERVER_TFDSERVER_H_ 1 + +#include +#include + +#include + +namespace apache +{ +namespace thrift +{ +namespace transport +{ + +class TFDServer : public TServerTransport +{ + +public: + /** + * Constructor. + * + * @param fd file descriptor of the socket + */ + TFDServer(int fd); + virtual ~TFDServer(); + + virtual bool isOpen() const override; + virtual THRIFT_SOCKET getSocketFD() override; + virtual void close() override; + + virtual void interrupt() override; + virtual void interruptChildren() override; + +protected: + TFDServer() : TFDServer(-1){}; + virtual std::shared_ptr acceptImpl() override; + + int fd; + std::vector> children; +}; +} // namespace transport +} // namespace thrift +} // namespace apache + +#endif /* _THRIFT_SERVER_TFDSERVER_H_ */ diff --git a/modules/thrift/src/thrift/server/TServer.h b/modules/thrift/src/thrift/server/TServer.h new file mode 100644 index 000000000000..b03d2b5b8d24 --- /dev/null +++ b/modules/thrift/src/thrift/server/TServer.h @@ -0,0 +1,338 @@ +/* + * Copyright (c) 2006- Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TSERVER_H_ +#define _THRIFT_SERVER_TSERVER_H_ 1 + +#include +#include +#include + +#include + +namespace apache +{ +namespace thrift +{ +namespace server +{ + +using apache::thrift::TProcessor; +using apache::thrift::protocol::TBinaryProtocolFactory; +using apache::thrift::protocol::TProtocol; +using apache::thrift::protocol::TProtocolFactory; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransport; +using apache::thrift::transport::TTransportFactory; + +/** + * Virtual interface class that can handle events from the server core. To + * use this you should subclass it and implement the methods that you care + * about. Your subclass can also store local data that you may care about, + * such as additional "arguments" to these methods (stored in the object + * instance's state). + */ +class TServerEventHandler +{ +public: + virtual ~TServerEventHandler() = default; + + /** + * Called before the server begins. + */ + virtual void preServe() + { + } + + /** + * Called when a new client has connected and is about to being processing. + */ + virtual void *createContext(std::shared_ptr input, + std::shared_ptr output) + { + (void)input; + (void)output; + return nullptr; + } + + /** + * Called when a client has finished request-handling to delete server + * context. + */ + virtual void deleteContext(void *serverContext, std::shared_ptr input, + std::shared_ptr output) + { + (void)serverContext; + (void)input; + (void)output; + } + + /** + * Called when a client is about to call the processor. + */ + virtual void processContext(void *serverContext, std::shared_ptr transport) + { + (void)serverContext; + (void)transport; + } + +protected: + /** + * Prevent direct instantiation. + */ + TServerEventHandler() = default; +}; + +/** + * Thrift server. + * + */ +class TServer +{ +public: + ~TServer() = default; + + virtual void serve() = 0; + + virtual void stop() + { + } + + // Allows running the server as a Runnable thread + void run() + { + serve(); + } + + std::shared_ptr getProcessorFactory() + { + return processorFactory_; + } + + std::shared_ptr getServerTransport() + { + return serverTransport_; + } + + std::shared_ptr getInputTransportFactory() + { + return inputTransportFactory_; + } + + std::shared_ptr getOutputTransportFactory() + { + return outputTransportFactory_; + } + + std::shared_ptr getInputProtocolFactory() + { + return inputProtocolFactory_; + } + + std::shared_ptr getOutputProtocolFactory() + { + return outputProtocolFactory_; + } + + std::shared_ptr getEventHandler() + { + return eventHandler_; + } + +protected: + TServer(const std::shared_ptr &processorFactory) + : processorFactory_(processorFactory) + { + setInputTransportFactory( + std::shared_ptr(new TTransportFactory())); + setOutputTransportFactory( + std::shared_ptr(new TTransportFactory())); + setInputProtocolFactory( + std::shared_ptr(new TBinaryProtocolFactory())); + setOutputProtocolFactory( + std::shared_ptr(new TBinaryProtocolFactory())); + } + + TServer(const std::shared_ptr &processor) + : processorFactory_(new TSingletonProcessorFactory(processor)) + { + setInputTransportFactory( + std::shared_ptr(new TTransportFactory())); + setOutputTransportFactory( + std::shared_ptr(new TTransportFactory())); + setInputProtocolFactory( + std::shared_ptr(new TBinaryProtocolFactory())); + setOutputProtocolFactory( + std::shared_ptr(new TBinaryProtocolFactory())); + } + + TServer(const std::shared_ptr &processorFactory, + const std::shared_ptr &serverTransport) + : processorFactory_(processorFactory), serverTransport_(serverTransport) + { + setInputTransportFactory( + std::shared_ptr(new TTransportFactory())); + setOutputTransportFactory( + std::shared_ptr(new TTransportFactory())); + setInputProtocolFactory( + std::shared_ptr(new TBinaryProtocolFactory())); + setOutputProtocolFactory( + std::shared_ptr(new TBinaryProtocolFactory())); + } + + TServer(const std::shared_ptr &processor, + const std::shared_ptr &serverTransport) + : processorFactory_(new TSingletonProcessorFactory(processor)), + serverTransport_(serverTransport) + { + setInputTransportFactory( + std::shared_ptr(new TTransportFactory())); + setOutputTransportFactory( + std::shared_ptr(new TTransportFactory())); + setInputProtocolFactory( + std::shared_ptr(new TBinaryProtocolFactory())); + setOutputProtocolFactory( + std::shared_ptr(new TBinaryProtocolFactory())); + } + + TServer(const std::shared_ptr &processorFactory, + const std::shared_ptr &serverTransport, + const std::shared_ptr &transportFactory, + const std::shared_ptr &protocolFactory) + : processorFactory_(processorFactory), serverTransport_(serverTransport), + inputTransportFactory_(transportFactory), + outputTransportFactory_(transportFactory), inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory) + { + } + + TServer(const std::shared_ptr &processor, + const std::shared_ptr &serverTransport, + const std::shared_ptr &transportFactory, + const std::shared_ptr &protocolFactory) + : processorFactory_(new TSingletonProcessorFactory(processor)), + serverTransport_(serverTransport), inputTransportFactory_(transportFactory), + outputTransportFactory_(transportFactory), inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory) + { + } + + TServer(const std::shared_ptr &processorFactory, + const std::shared_ptr &serverTransport, + const std::shared_ptr &inputTransportFactory, + const std::shared_ptr &outputTransportFactory, + const std::shared_ptr &inputProtocolFactory, + const std::shared_ptr &outputProtocolFactory) + : processorFactory_(processorFactory), serverTransport_(serverTransport), + inputTransportFactory_(inputTransportFactory), + outputTransportFactory_(outputTransportFactory), + inputProtocolFactory_(inputProtocolFactory), + outputProtocolFactory_(outputProtocolFactory) + { + } + + TServer(const std::shared_ptr &processor, + const std::shared_ptr &serverTransport, + const std::shared_ptr &inputTransportFactory, + const std::shared_ptr &outputTransportFactory, + const std::shared_ptr &inputProtocolFactory, + const std::shared_ptr &outputProtocolFactory) + : processorFactory_(new TSingletonProcessorFactory(processor)), + serverTransport_(serverTransport), inputTransportFactory_(inputTransportFactory), + outputTransportFactory_(outputTransportFactory), + inputProtocolFactory_(inputProtocolFactory), + outputProtocolFactory_(outputProtocolFactory) + { + } + + /** + * Get a TProcessor to handle calls on a particular connection. + * + * This method should only be called once per connection (never once per + * call). This allows the TProcessorFactory to return a different processor + * for each connection if it desires. + */ + std::shared_ptr getProcessor(std::shared_ptr inputProtocol, + std::shared_ptr outputProtocol, + std::shared_ptr transport) + { + TConnectionInfo connInfo; + connInfo.input = inputProtocol; + connInfo.output = outputProtocol; + connInfo.transport = transport; + return processorFactory_->getProcessor(connInfo); + } + + // Class variables + std::shared_ptr processorFactory_; + std::shared_ptr serverTransport_; + + std::shared_ptr inputTransportFactory_; + std::shared_ptr outputTransportFactory_; + + std::shared_ptr inputProtocolFactory_; + std::shared_ptr outputProtocolFactory_; + + std::shared_ptr eventHandler_; + +public: + void setInputTransportFactory(std::shared_ptr inputTransportFactory) + { + inputTransportFactory_ = inputTransportFactory; + } + + void setOutputTransportFactory(std::shared_ptr outputTransportFactory) + { + outputTransportFactory_ = outputTransportFactory; + } + + void setInputProtocolFactory(std::shared_ptr inputProtocolFactory) + { + inputProtocolFactory_ = inputProtocolFactory; + } + + void setOutputProtocolFactory(std::shared_ptr outputProtocolFactory) + { + outputProtocolFactory_ = outputProtocolFactory; + } + + void setServerEventHandler(std::shared_ptr eventHandler) + { + eventHandler_ = eventHandler; + } +}; + +/** + * Helper function to increase the max file descriptors limit + * for the current process and all of its children. + * By default, tries to increase it to as much as 2^24. + */ +#ifdef HAVE_SYS_RESOURCE_H +int increase_max_fds(int max_fds = (1 << 24)); +#endif +} // namespace server +} // namespace thrift +} // namespace apache + +#endif // #ifndef _THRIFT_SERVER_TSERVER_H_ diff --git a/modules/thrift/src/thrift/server/TServerFramework.cpp b/modules/thrift/src/thrift/server/TServerFramework.cpp new file mode 100644 index 000000000000..c5061fcaeaca --- /dev/null +++ b/modules/thrift/src/thrift/server/TServerFramework.cpp @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2006- Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +namespace apache +{ +namespace thrift +{ +namespace server +{ + +// using apache::thrift::concurrency::Synchronized; +using apache::thrift::protocol::TProtocol; +using apache::thrift::protocol::TProtocolFactory; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransport; +using apache::thrift::transport::TTransportException; +using apache::thrift::transport::TTransportFactory; +using std::bind; +using std::shared_ptr; +using std::string; + +TServerFramework::TServerFramework(const shared_ptr &processorFactory, + const shared_ptr &serverTransport, + const shared_ptr &transportFactory, + const shared_ptr &protocolFactory) + : TServer(processorFactory, serverTransport, transportFactory, protocolFactory), + clients_(0), hwm_(0), limit_(INT64_MAX) +{ +} + +TServerFramework::TServerFramework(const shared_ptr &processor, + const shared_ptr &serverTransport, + const shared_ptr &transportFactory, + const shared_ptr &protocolFactory) + : TServer(processor, serverTransport, transportFactory, protocolFactory), clients_(0), + hwm_(0), limit_(INT64_MAX) +{ +} + +TServerFramework::TServerFramework(const shared_ptr &processorFactory, + const shared_ptr &serverTransport, + const shared_ptr &inputTransportFactory, + const shared_ptr &outputTransportFactory, + const shared_ptr &inputProtocolFactory, + const shared_ptr &outputProtocolFactory) + : TServer(processorFactory, serverTransport, inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory), + clients_(0), hwm_(0), limit_(INT64_MAX) +{ +} + +TServerFramework::TServerFramework(const shared_ptr &processor, + const shared_ptr &serverTransport, + const shared_ptr &inputTransportFactory, + const shared_ptr &outputTransportFactory, + const shared_ptr &inputProtocolFactory, + const shared_ptr &outputProtocolFactory) + : TServer(processor, serverTransport, inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory), + clients_(0), hwm_(0), limit_(INT64_MAX) +{ +} + +TServerFramework::~TServerFramework() = default; + +template static void releaseOneDescriptor(const string &name, T &pTransport) +{ + if (pTransport) { + try { + pTransport->close(); + } catch (const TTransportException &ttx) { + string errStr = + string("TServerFramework " + name + " close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + } +} + +void TServerFramework::serve() +{ + shared_ptr client; + shared_ptr inputTransport; + shared_ptr outputTransport; + shared_ptr inputProtocol; + shared_ptr outputProtocol; + + // Start the server listening + serverTransport_->listen(); + + // Run the preServe event to indicate server is now listening + // and that it is safe to connect. + if (eventHandler_) { + eventHandler_->preServe(); + } + + // Fetch client from server + for (;;) { + try { + // Dereference any resources from any previous client creation + // such that a blocking accept does not hold them indefinitely. + outputProtocol.reset(); + inputProtocol.reset(); + outputTransport.reset(); + inputTransport.reset(); + client.reset(); + + // If we have reached the limit on the number of concurrent + // clients allowed, wait for one or more clients to drain before + // accepting another. + { + // Synchronized sync(mon_); + while (clients_ >= limit_) { + // mon_.wait(); + } + } + + client = serverTransport_->accept(); + + inputTransport = inputTransportFactory_->getTransport(client); + outputTransport = outputTransportFactory_->getTransport(client); + if (!outputProtocolFactory_) { + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport, + outputTransport); + outputProtocol = inputProtocol; + } else { + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport); + outputProtocol = + outputProtocolFactory_->getProtocol(outputTransport); + } + + newlyConnectedClient(shared_ptr( + new TConnectedClient( + getProcessor(inputProtocol, outputProtocol, client), + inputProtocol, outputProtocol, eventHandler_, client), + bind(&TServerFramework::disposeConnectedClient, this, + std::placeholders::_1))); + + } catch (TTransportException &ttx) { + releaseOneDescriptor("inputTransport", inputTransport); + releaseOneDescriptor("outputTransport", outputTransport); + releaseOneDescriptor("client", client); + if (ttx.getType() == TTransportException::TIMED_OUT || + ttx.getType() == TTransportException::CLIENT_DISCONNECT) { + // Accept timeout and client disconnect - continue processing. + continue; + } else if (ttx.getType() == TTransportException::END_OF_FILE || + ttx.getType() == TTransportException::INTERRUPTED) { + // Server was interrupted. This only happens when stopping. + break; + } else { + // All other transport exceptions are logged. + // State of connection is unknown. Done. + string errStr = string("TServerTransport died: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + break; + } + } + } + + releaseOneDescriptor("serverTransport", serverTransport_); +} + +int64_t TServerFramework::getConcurrentClientLimit() const +{ + // Synchronized sync(mon_); + return limit_; +} + +int64_t TServerFramework::getConcurrentClientCount() const +{ + // Synchronized sync(mon_); + return clients_; +} + +int64_t TServerFramework::getConcurrentClientCountHWM() const +{ + // Synchronized sync(mon_); + return hwm_; +} + +void TServerFramework::setConcurrentClientLimit(int64_t newLimit) +{ + if (newLimit < 1) { + throw std::invalid_argument("newLimit must be greater than zero"); + } + // Synchronized sync(mon_); + limit_ = newLimit; + if (limit_ - clients_ > 0) { + // mon_.notify(); + } +} + +void TServerFramework::stop() +{ + // Order is important because serve() releases serverTransport_ when it is + // interrupted, which closes the socket that interruptChildren uses. + serverTransport_->interruptChildren(); + serverTransport_->interrupt(); +} + +void TServerFramework::newlyConnectedClient(const shared_ptr &pClient) +{ + { + // Synchronized sync(mon_); + ++clients_; + hwm_ = (std::max)(hwm_, clients_); + } + + onClientConnected(pClient); +} + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdelete-non-virtual-dtor" +void TServerFramework::disposeConnectedClient(TConnectedClient *pClient) +{ + onClientDisconnected(pClient); + delete pClient; + + // Synchronized sync(mon_); + if (limit_ - --clients_ > 0) { + // mon_.notify(); + } +} +#pragma GCC diagnostic pop + +} // namespace server +} // namespace thrift +} // namespace apache diff --git a/modules/thrift/src/thrift/server/TServerFramework.h b/modules/thrift/src/thrift/server/TServerFramework.h new file mode 100644 index 000000000000..298f31f3b460 --- /dev/null +++ b/modules/thrift/src/thrift/server/TServerFramework.h @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2006- Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TSERVERFRAMEWORK_H_ +#define _THRIFT_SERVER_TSERVERFRAMEWORK_H_ 1 + +#include +#include +#include +#include +#include +#include +#include + +namespace apache +{ +namespace thrift +{ +namespace server +{ + +/** + * TServerFramework provides a single consolidated processing loop for + * servers. By having a single processing loop, behavior between servers + * is more predictable and maintenance cost is lowered. Implementations + * of TServerFramework must provide a method to deal with a client that + * connects and one that disconnects. + * + * While this functionality could be rolled directly into TServer, and + * probably should be, it would break the TServer interface contract so + * to maintain backwards compatibility for third party servers, no TServers + * were harmed in the making of this class. + */ +class TServerFramework : public TServer +{ +public: + TServerFramework( + const std::shared_ptr &processorFactory, + const std::shared_ptr &serverTransport, + const std::shared_ptr + &transportFactory, + const std::shared_ptr &protocolFactory); + + TServerFramework( + const std::shared_ptr &processor, + const std::shared_ptr &serverTransport, + const std::shared_ptr + &transportFactory, + const std::shared_ptr &protocolFactory); + + TServerFramework( + const std::shared_ptr &processorFactory, + const std::shared_ptr &serverTransport, + const std::shared_ptr + &inputTransportFactory, + const std::shared_ptr + &outputTransportFactory, + const std::shared_ptr + &inputProtocolFactory, + const std::shared_ptr + &outputProtocolFactory); + + TServerFramework( + const std::shared_ptr &processor, + const std::shared_ptr &serverTransport, + const std::shared_ptr + &inputTransportFactory, + const std::shared_ptr + &outputTransportFactory, + const std::shared_ptr + &inputProtocolFactory, + const std::shared_ptr + &outputProtocolFactory); + + ~TServerFramework(); + + /** + * Accept clients from the TServerTransport and add them for processing. + * Call stop() on another thread to interrupt processing + * and return control to the caller. + * Post-conditions (return guarantees): + * The serverTransport will be closed. + */ + virtual void serve() override; + + /** + * Interrupt serve() so that it meets post-conditions and returns. + */ + virtual void stop() override; + + /** + * Get the concurrent client limit. + * \returns the concurrent client limit + */ + virtual int64_t getConcurrentClientLimit() const; + + /** + * Get the number of currently connected clients. + * \returns the number of currently connected clients + */ + virtual int64_t getConcurrentClientCount() const; + + /** + * Get the highest number of concurrent clients. + * \returns the highest number of concurrent clients + */ + virtual int64_t getConcurrentClientCountHWM() const; + + /** + * Set the concurrent client limit. This can be changed while + * the server is serving however it will not necessarily be + * enforced until the next client is accepted and added. If the + * limit is lowered below the number of connected clients, no + * action is taken to disconnect the clients. + * The default value used if this is not called is INT64_MAX. + * \param[in] newLimit the new limit of concurrent clients + * \throws std::invalid_argument if newLimit is less than 1 + */ + virtual void setConcurrentClientLimit(int64_t newLimit); + +protected: + /** + * A client has connected. The implementation is responsible for managing the + * lifetime of the client object. This is called during the serve() thread, + * therefore a failure to return quickly will result in new client connection + * delays. + * + * \param[in] pClient the newly connected client + */ + virtual void onClientConnected(const std::shared_ptr &pClient) = 0; + + /** + * A client has disconnected. + * When called: + * The server no longer tracks the client. + * The client TTransport has already been closed. + * The implementation must not delete the pointer. + * + * \param[in] pClient the disconnected client + */ + virtual void onClientDisconnected(TConnectedClient *pClient) = 0; + +private: + /** + * Common handling for new connected clients. Implements concurrent + * client rate limiting after onClientConnected returns by blocking the + * serve() thread if the limit has been reached. + */ + void newlyConnectedClient(const std::shared_ptr &pClient); + + /** + * Smart pointer client deletion. + * Calls onClientDisconnected and then deletes pClient. + */ + void disposeConnectedClient(TConnectedClient *pClient); + + /** + * The number of concurrent clients. + */ + int64_t clients_; + + /** + * The high water mark of concurrent clients. + */ + int64_t hwm_; + + /** + * The limit on the number of concurrent clients. + */ + int64_t limit_; +}; +} // namespace server +} // namespace thrift +} // namespace apache + +#endif // #ifndef _THRIFT_SERVER_TSERVERFRAMEWORK_H_ diff --git a/modules/thrift/src/thrift/server/TSimpleServer.h b/modules/thrift/src/thrift/server/TSimpleServer.h new file mode 100644 index 000000000000..dc596dbef1f5 --- /dev/null +++ b/modules/thrift/src/thrift/server/TSimpleServer.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2006- Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ +#define _THRIFT_SERVER_TSIMPLESERVER_H_ 1 + +#include + +namespace apache +{ +namespace thrift +{ +namespace server +{ + +/** + * This is the most basic simple server. It is single-threaded and runs a + * continuous loop of accepting a single connection, processing requests on + * that connection until it closes, and then repeating. + */ +class TSimpleServer : public TServerFramework +{ +public: + TSimpleServer( + const std::shared_ptr &processorFactory, + const std::shared_ptr &serverTransport, + const std::shared_ptr + &transportFactory, + const std::shared_ptr &protocolFactory); + + TSimpleServer( + const std::shared_ptr &processor, + const std::shared_ptr &serverTransport, + const std::shared_ptr + &transportFactory, + const std::shared_ptr &protocolFactory); + + TSimpleServer( + const std::shared_ptr &processorFactory, + const std::shared_ptr &serverTransport, + const std::shared_ptr + &inputTransportFactory, + const std::shared_ptr + &outputTransportFactory, + const std::shared_ptr + &inputProtocolFactory, + const std::shared_ptr + &outputProtocolFactory); + + TSimpleServer( + const std::shared_ptr &processor, + const std::shared_ptr &serverTransport, + const std::shared_ptr + &inputTransportFactory, + const std::shared_ptr + &outputTransportFactory, + const std::shared_ptr + &inputProtocolFactory, + const std::shared_ptr + &outputProtocolFactory); + + ~TSimpleServer(); + +protected: + void onClientConnected(const std::shared_ptr &pClient) override + /* override */; + void onClientDisconnected(TConnectedClient *pClient) override /* override */; + +private: + void setConcurrentClientLimit(int64_t newLimit) override; // hide +}; +} // namespace server +} // namespace thrift +} // namespace apache + +#endif // #ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ diff --git a/modules/thrift/src/thrift/transport/TSSLServerSocket.cpp b/modules/thrift/src/thrift/transport/TSSLServerSocket.cpp new file mode 100644 index 000000000000..f4390e521546 --- /dev/null +++ b/modules/thrift/src/thrift/transport/TSSLServerSocket.cpp @@ -0,0 +1,244 @@ +/* + * Copyright 2006 Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +template inline void *cast_sockopt(T *v) +{ + return reinterpret_cast(v); +} + +void destroyer_of_fine_sockets(THRIFT_SOCKET *ssock); + +namespace apache +{ +namespace thrift +{ +namespace transport +{ + +/** + * SSL server socket implementation. + */ +TSSLServerSocket::TSSLServerSocket(int port, std::shared_ptr factory) + : TServerSocket(port), factory_(factory) +{ + factory_->server(true); +} + +TSSLServerSocket::TSSLServerSocket(const std::string &address, int port, + std::shared_ptr factory) + : TServerSocket(address, port), factory_(factory) +{ + factory_->server(true); +} + +TSSLServerSocket::TSSLServerSocket(int port, int sendTimeout, int recvTimeout, + std::shared_ptr factory) + : TServerSocket(port, sendTimeout, recvTimeout), factory_(factory) +{ + factory_->server(true); +} + +std::shared_ptr TSSLServerSocket::createSocket(THRIFT_SOCKET client) +{ + if (interruptableChildren_) { + return factory_->createSocket(client, pChildInterruptSockReader_); + + } else { + return factory_->createSocket(client); + } +} + +void TSSLServerSocket::listen() +{ + THRIFT_SOCKET sv[2]; + // Create the socket pair used to interrupt + if (-1 == THRIFT_SOCKETPAIR(AF_LOCAL, SOCK_STREAM, 0, sv)) { + GlobalOutput.perror("TServerSocket::listen() socketpair() interrupt", + THRIFT_GET_SOCKET_ERROR); + interruptSockWriter_ = THRIFT_INVALID_SOCKET; + interruptSockReader_ = THRIFT_INVALID_SOCKET; + } else { + interruptSockWriter_ = sv[1]; + interruptSockReader_ = sv[0]; + } + + // Create the socket pair used to interrupt all clients + if (-1 == THRIFT_SOCKETPAIR(AF_LOCAL, SOCK_STREAM, 0, sv)) { + GlobalOutput.perror("TServerSocket::listen() socketpair() childInterrupt", + THRIFT_GET_SOCKET_ERROR); + childInterruptSockWriter_ = THRIFT_INVALID_SOCKET; + pChildInterruptSockReader_.reset(); + } else { + childInterruptSockWriter_ = sv[1]; + pChildInterruptSockReader_ = std::shared_ptr( + new THRIFT_SOCKET(sv[0]), destroyer_of_fine_sockets); + } + + // Validate port number + if (port_ < 0 || port_ > 0xFFFF) { + throw TTransportException(TTransportException::BAD_ARGS, + "Specified port is invalid"); + } + + // Resolve host:port strings into an iterable of struct addrinfo* + AddressResolutionHelper resolved_addresses; + try { + resolved_addresses.resolve(address_, std::to_string(port_), SOCK_STREAM, + AI_PASSIVE | AI_V4MAPPED); + + } catch (const std::system_error &e) { + GlobalOutput.printf("getaddrinfo() -> %d; %s", e.code().value(), e.what()); + close(); + throw TTransportException(TTransportException::NOT_OPEN, + "Could not resolve host for server socket."); + } + + // we may want to try to bind more than once, since THRIFT_NO_SOCKET_CACHING doesn't + // always seem to work. The client can configure the retry variables. + int retries = 0; + int errno_copy = 0; + + // -- TCP socket -- // + + auto addr_iter = AddressResolutionHelper::Iter{}; + + // Via DNS or somehow else, single hostname can resolve into many addresses. + // Results may contain perhaps a mix of IPv4 and IPv6. Here, we iterate + // over what system gave us, picking the first address that works. + do { + if (!addr_iter) { + // init + recycle over many retries + addr_iter = resolved_addresses.iterate(); + } + auto trybind = *addr_iter++; + + serverSocket_ = socket(trybind->ai_family, trybind->ai_socktype, IPPROTO_TLS_1_2); + if (serverSocket_ == -1) { + errno_copy = THRIFT_GET_SOCKET_ERROR; + continue; + } + + _setup_sockopts(); + _setup_tcp_sockopts(); + + static const sec_tag_t sec_tag_list[3] = { + Thrift_TLS_CA_CERT_TAG, Thrift_TLS_SERVER_CERT_TAG, Thrift_TLS_PRIVATE_KEY}; + + int ret = setsockopt(serverSocket_, SOL_TLS, TLS_SEC_TAG_LIST, sec_tag_list, + sizeof(sec_tag_list)); + if (ret != 0) { + throw TTransportException(TTransportException::NOT_OPEN, + "set TLS_SEC_TAG_LIST failed"); + } + +#ifdef IPV6_V6ONLY + if (trybind->ai_family == AF_INET6) { + int zero = 0; + if (-1 == setsockopt(serverSocket_, IPPROTO_IPV6, IPV6_V6ONLY, + cast_sockopt(&zero), sizeof(zero))) { + GlobalOutput.perror("TServerSocket::listen() IPV6_V6ONLY ", + THRIFT_GET_SOCKET_ERROR); + } + } +#endif // #ifdef IPV6_V6ONLY + + if (0 == ::bind(serverSocket_, trybind->ai_addr, + static_cast(trybind->ai_addrlen))) { + break; + } + errno_copy = THRIFT_GET_SOCKET_ERROR; + + // use short circuit evaluation here to only sleep if we need to + } while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0)); + + // retrieve bind info + if (port_ == 0 && retries <= retryLimit_) { + struct sockaddr_storage sa; + socklen_t len = sizeof(sa); + std::memset(&sa, 0, len); + if (::getsockname(serverSocket_, reinterpret_cast(&sa), &len) < + 0) { + errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy); + } else { + if (sa.ss_family == AF_INET6) { + const auto *sin = + reinterpret_cast(&sa); + port_ = ntohs(sin->sin6_port); + } else { + const auto *sin = reinterpret_cast(&sa); + port_ = ntohs(sin->sin_port); + } + } + } + + // throw error if socket still wasn't created successfully + if (serverSocket_ == THRIFT_INVALID_SOCKET) { + GlobalOutput.perror("TServerSocket::listen() socket() ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, + "Could not create server socket.", errno_copy); + } + + // throw an error if we failed to bind properly + if (retries > retryLimit_) { + char errbuf[1024]; + + THRIFT_SNPRINTF(errbuf, sizeof(errbuf), + "TServerSocket::listen() Could not bind to port %d", port_); + + GlobalOutput(errbuf); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not bind", + errno_copy); + } + + if (listenCallback_) { + listenCallback_(serverSocket_); + } + + // Call listen + if (-1 == ::listen(serverSocket_, acceptBacklog_)) { + errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not listen", + errno_copy); + } + + // The socket is now listening! + listening_ = true; +} + +void TSSLServerSocket::close() +{ + rwMutex_.lock(); + if (pChildInterruptSockReader_ != nullptr && + *pChildInterruptSockReader_ != THRIFT_INVALID_SOCKET) { + ::THRIFT_CLOSESOCKET(*pChildInterruptSockReader_); + *pChildInterruptSockReader_ = THRIFT_INVALID_SOCKET; + } + + rwMutex_.unlock(); + + TServerSocket::close(); +} + +} // namespace transport +} // namespace thrift +} // namespace apache diff --git a/modules/thrift/src/thrift/transport/TSSLServerSocket.h b/modules/thrift/src/thrift/transport/TSSLServerSocket.h new file mode 100644 index 000000000000..582741b97355 --- /dev/null +++ b/modules/thrift/src/thrift/transport/TSSLServerSocket.h @@ -0,0 +1,67 @@ +/* + * Copyright 2006 Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef _THRIFT_TRANSPORT_TSSLSERVERSOCKET_H_ +#define _THRIFT_TRANSPORT_TSSLSERVERSOCKET_H_ 1 + +#include + +namespace apache +{ +namespace thrift +{ +namespace transport +{ + +class TSSLSocketFactory; + +/** + * Server socket that accepts SSL connections. + */ +class TSSLServerSocket : public TServerSocket +{ +public: + /** + * Constructor. Binds to all interfaces. + * + * @param port Listening port + * @param factory SSL socket factory implementation + */ + TSSLServerSocket(int port, std::shared_ptr factory); + + /** + * Constructor. Binds to the specified address. + * + * @param address Address to bind to + * @param port Listening port + * @param factory SSL socket factory implementation + */ + TSSLServerSocket(const std::string &address, int port, + std::shared_ptr factory); + + /** + * Constructor. Binds to all interfaces. + * + * @param port Listening port + * @param sendTimeout Socket send timeout + * @param recvTimeout Socket receive timeout + * @param factory SSL socket factory implementation + */ + TSSLServerSocket(int port, int sendTimeout, int recvTimeout, + std::shared_ptr factory); + + void listen() override; + void close() override; + +protected: + std::shared_ptr createSocket(THRIFT_SOCKET socket) override; + std::shared_ptr factory_; +}; +} // namespace transport +} // namespace thrift +} // namespace apache + +#endif diff --git a/modules/thrift/src/thrift/transport/TSSLSocket.cpp b/modules/thrift/src/thrift/transport/TSSLSocket.cpp new file mode 100644 index 000000000000..3ac178ede3be --- /dev/null +++ b/modules/thrift/src/thrift/transport/TSSLSocket.cpp @@ -0,0 +1,656 @@ +/* + * Copyright 2006 Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#ifdef HAVE_ARPA_INET_H +#include +#endif +#include +#ifdef HAVE_POLL_H +#include +#endif + +#include + +#include + +#include +#include +#include +#include +#include + +using namespace apache::thrift::concurrency; +using std::string; + +struct CRYPTO_dynlock_value { + Mutex mutex; +}; + +namespace apache +{ +namespace thrift +{ +namespace transport +{ + +static bool matchName(const char *host, const char *pattern, int size); +static char uppercase(char c); + +// TSSLSocket implementation +TSSLSocket::TSSLSocket(std::shared_ptr ctx, std::shared_ptr config) + : TSocket(config), server_(false), ctx_(ctx) +{ + init(); +} + +TSSLSocket::TSSLSocket(std::shared_ptr ctx, + std::shared_ptr interruptListener, + std::shared_ptr config) + : TSocket(config), server_(false), ctx_(ctx) +{ + init(); + interruptListener_ = interruptListener; +} + +TSSLSocket::TSSLSocket(std::shared_ptr ctx, THRIFT_SOCKET socket, + std::shared_ptr config) + : TSocket(socket, config), server_(false), ctx_(ctx) +{ + init(); +} + +TSSLSocket::TSSLSocket(std::shared_ptr ctx, THRIFT_SOCKET socket, + std::shared_ptr interruptListener, + std::shared_ptr config) + : TSocket(socket, interruptListener, config), server_(false), ctx_(ctx) +{ + init(); +} + +TSSLSocket::TSSLSocket(std::shared_ptr ctx, string host, int port, + std::shared_ptr config) + : TSocket(host, port, config), server_(false), ctx_(ctx) +{ + init(); +} + +TSSLSocket::TSSLSocket(std::shared_ptr ctx, string host, int port, + std::shared_ptr interruptListener, + std::shared_ptr config) + : TSocket(host, port, config), server_(false), ctx_(ctx) +{ + init(); + interruptListener_ = interruptListener; +} + +TSSLSocket::~TSSLSocket() +{ + close(); +} + +template inline void *cast_sockopt(T *v) +{ + return reinterpret_cast(v); +} + +void TSSLSocket::authorize() +{ +} + +void TSSLSocket::openSecConnection(struct addrinfo *res) +{ + socket_ = socket(res->ai_family, res->ai_socktype, ctx_->protocol); + + if (socket_ == THRIFT_INVALID_SOCKET) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TSocket::open() socket() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "socket()", errno_copy); + } + + static const sec_tag_t sec_tag_list[3] = { + Thrift_TLS_CA_CERT_TAG, Thrift_TLS_SERVER_CERT_TAG, Thrift_TLS_PRIVATE_KEY}; + + int ret = + setsockopt(socket_, SOL_TLS, TLS_SEC_TAG_LIST, sec_tag_list, sizeof(sec_tag_list)); + if (ret != 0) { + throw TTransportException(TTransportException::NOT_OPEN, + "set TLS_SEC_TAG_LIST failed"); + } + + ret = setsockopt(socket_, SOL_TLS, TLS_PEER_VERIFY, &(ctx_->verifyMode), + sizeof(ctx_->verifyMode)); + if (ret != 0) { + throw TTransportException(TTransportException::NOT_OPEN, + "set TLS_PEER_VERIFY failed"); + } + + ret = setsockopt(socket_, SOL_TLS, TLS_HOSTNAME, host_.c_str(), host_.size()); + if (ret != 0) { + throw TTransportException(TTransportException::NOT_OPEN, "set TLS_HOSTNAME failed"); + } + + // Send timeout + if (sendTimeout_ > 0) { + setSendTimeout(sendTimeout_); + } + + // Recv timeout + if (recvTimeout_ > 0) { + setRecvTimeout(recvTimeout_); + } + + if (keepAlive_) { + setKeepAlive(keepAlive_); + } + + // Linger + setLinger(lingerOn_, lingerVal_); + + // No delay + setNoDelay(noDelay_); + +#ifdef SO_NOSIGPIPE + { + int one = 1; + setsockopt(socket_, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one)); + } +#endif + +// Uses a low min RTO if asked to. +#ifdef TCP_LOW_MIN_RTO + if (getUseLowMinRto()) { + int one = 1; + setsockopt(socket_, IPPROTO_TCP, TCP_LOW_MIN_RTO, &one, sizeof(one)); + } +#endif + + // Set the socket to be non blocking for connect if a timeout exists + int flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0); + if (connTimeout_ > 0) { + if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK)) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TSocket::open() THRIFT_FCNTL() " + getSocketInfo(), + errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, + "THRIFT_FCNTL() failed", errno_copy); + } + } else { + if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags & ~THRIFT_O_NONBLOCK)) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TSocket::open() THRIFT_FCNTL " + getSocketInfo(), + errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, + "THRIFT_FCNTL() failed", errno_copy); + } + } + + // Connect the socket + + ret = connect(socket_, res->ai_addr, static_cast(res->ai_addrlen)); + + // success case + if (ret == 0) { + goto done; + } + + if ((THRIFT_GET_SOCKET_ERROR != THRIFT_EINPROGRESS) && + (THRIFT_GET_SOCKET_ERROR != THRIFT_EWOULDBLOCK)) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TSocket::open() connect() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "connect() failed", + errno_copy); + } + + struct THRIFT_POLLFD fds[1]; + std::memset(fds, 0, sizeof(fds)); + fds[0].fd = socket_; + fds[0].events = THRIFT_POLLOUT; + ret = THRIFT_POLL(fds, 1, connTimeout_); + + if (ret > 0) { + // Ensure the socket is connected and that there are no errors set + int val; + socklen_t lon; + lon = sizeof(int); + int ret2 = getsockopt(socket_, SOL_SOCKET, SO_ERROR, cast_sockopt(&val), &lon); + if (ret2 == -1) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TSocket::open() getsockopt() " + getSocketInfo(), + errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "getsockopt()", + errno_copy); + } + // no errors on socket, go to town + if (val == 0) { + goto done; + } + GlobalOutput.perror("TSocket::open() error on socket (after THRIFT_POLL) " + + getSocketInfo(), + val); + throw TTransportException(TTransportException::NOT_OPEN, "socket open() error", + val); + } else if (ret == 0) { + // socket timed out + string errStr = "TSocket::open() timed out " + getSocketInfo(); + GlobalOutput(errStr.c_str()); + throw TTransportException(TTransportException::NOT_OPEN, "open() timed out"); + } else { + // error on THRIFT_POLL() + int errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TSocket::open() THRIFT_POLL() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_POLL() failed", + errno_copy); + } + +done: + // Set socket back to normal mode (blocking) + if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags)) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TSocket::open() THRIFT_FCNTL " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_FCNTL() failed", + errno_copy); + } + + setCachedAddress(res->ai_addr, static_cast(res->ai_addrlen)); +} + +void TSSLSocket::init() +{ + handshakeCompleted_ = false; + readRetryCount_ = 0; + eventSafe_ = false; +} + +void TSSLSocket::open() +{ + if (isOpen() || server()) { + throw TTransportException(TTransportException::BAD_ARGS); + } + + // Validate port number + if (port_ < 0 || port_ > 0xFFFF) { + throw TTransportException(TTransportException::BAD_ARGS, + "Specified port is invalid"); + } + + struct addrinfo hints, *res, *res0; + res = nullptr; + res0 = nullptr; + int error; + char port[sizeof("65535")]; + std::memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + sprintf(port, "%d", port_); + + error = getaddrinfo(host_.c_str(), port, &hints, &res0); + + if (error == DNS_EAI_NODATA) { + hints.ai_flags &= ~AI_ADDRCONFIG; + error = getaddrinfo(host_.c_str(), port, &hints, &res0); + } + + if (error) { + string errStr = "TSocket::open() getaddrinfo() " + getSocketInfo() + + string(THRIFT_GAI_STRERROR(error)); + GlobalOutput(errStr.c_str()); + close(); + throw TTransportException(TTransportException::NOT_OPEN, + "Could not resolve host for client socket."); + } + + // Cycle through all the returned addresses until one + // connects or push the exception up. + for (res = res0; res; res = res->ai_next) { + try { + openSecConnection(res); + break; + } catch (TTransportException &) { + if (res->ai_next) { + close(); + } else { + close(); + freeaddrinfo(res0); // cleanup on failure + throw; + } + } + } + + // Free address structure memory + freeaddrinfo(res0); +} + +TSSLSocketFactory::TSSLSocketFactory(SSLProtocol protocol) + : ctx_(std::make_shared()), server_(false) +{ + switch (protocol) { + case SSLTLS: + break; + case TLSv1_0: + break; + case TLSv1_1: + ctx_->protocol = IPPROTO_TLS_1_1; + break; + case TLSv1_2: + ctx_->protocol = IPPROTO_TLS_1_2; + break; + default: + throw TTransportException(TTransportException::BAD_ARGS, + "Specified protocol is invalid"); + } +} + +TSSLSocketFactory::~TSSLSocketFactory() +{ +} + +std::shared_ptr TSSLSocketFactory::createSocket() +{ + std::shared_ptr ssl(new TSSLSocket(ctx_)); + setup(ssl); + return ssl; +} + +std::shared_ptr +TSSLSocketFactory::createSocket(std::shared_ptr interruptListener) +{ + std::shared_ptr ssl(new TSSLSocket(ctx_, interruptListener)); + setup(ssl); + return ssl; +} + +std::shared_ptr TSSLSocketFactory::createSocket(THRIFT_SOCKET socket) +{ + std::shared_ptr ssl(new TSSLSocket(ctx_, socket)); + setup(ssl); + return ssl; +} + +std::shared_ptr +TSSLSocketFactory::createSocket(THRIFT_SOCKET socket, + std::shared_ptr interruptListener) +{ + std::shared_ptr ssl(new TSSLSocket(ctx_, socket, interruptListener)); + setup(ssl); + return ssl; +} + +std::shared_ptr TSSLSocketFactory::createSocket(const string &host, int port) +{ + std::shared_ptr ssl(new TSSLSocket(ctx_, host, port)); + setup(ssl); + return ssl; +} + +std::shared_ptr +TSSLSocketFactory::createSocket(const string &host, int port, + std::shared_ptr interruptListener) +{ + std::shared_ptr ssl(new TSSLSocket(ctx_, host, port, interruptListener)); + setup(ssl); + return ssl; +} + +static void tlsCredtErrMsg(string &errors, const int status); + +void TSSLSocketFactory::setup(std::shared_ptr ssl) +{ + ssl->server(server()); + if (access_ == nullptr && !server()) { + access_ = std::shared_ptr(new DefaultClientAccessManager); + } + if (access_ != nullptr) { + ssl->access(access_); + } +} + +void TSSLSocketFactory::ciphers(const string &enable) +{ +} + +void TSSLSocketFactory::authenticate(bool required) +{ + if (required) { + ctx_->verifyMode = TLS_PEER_VERIFY_REQUIRED; + } else { + ctx_->verifyMode = TLS_PEER_VERIFY_NONE; + } +} + +void TSSLSocketFactory::loadCertificate(const char *path, const char *format) +{ + if (path == nullptr || format == nullptr) { + throw TTransportException( + TTransportException::BAD_ARGS, + "loadCertificateChain: either or is nullptr"); + } + if (strcmp(format, "PEM") == 0) { + + } else { + throw TSSLException("Unsupported certificate format: " + string(format)); + } +} + +void TSSLSocketFactory::loadCertificateFromBuffer(const char *aCertificate, const char *format) +{ + if (aCertificate == nullptr || format == nullptr) { + throw TTransportException(TTransportException::BAD_ARGS, + "loadCertificate: either or is nullptr"); + } + + if (strcmp(format, "PEM") == 0) { + const int status = tls_credential_add(Thrift_TLS_SERVER_CERT_TAG, + TLS_CREDENTIAL_SERVER_CERTIFICATE, + aCertificate, strlen(aCertificate) + 1); + + if (status != 0) { + string errors; + tlsCredtErrMsg(errors, status); + throw TSSLException("tls_credential_add: " + errors); + } + } else { + throw TSSLException("Unsupported certificate format: " + string(format)); + } +} + +void TSSLSocketFactory::loadPrivateKey(const char *path, const char *format) +{ + if (path == nullptr || format == nullptr) { + throw TTransportException(TTransportException::BAD_ARGS, + "loadPrivateKey: either or is nullptr"); + } + if (strcmp(format, "PEM") == 0) { + if (0) { + string errors; + // tlsCredtErrMsg(errors, status); + throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors); + } + } +} + +void TSSLSocketFactory::loadPrivateKeyFromBuffer(const char *aPrivateKey, const char *format) +{ + if (aPrivateKey == nullptr || format == nullptr) { + throw TTransportException(TTransportException::BAD_ARGS, + "loadPrivateKey: either or is nullptr"); + } + if (strcmp(format, "PEM") == 0) { + const int status = + tls_credential_add(Thrift_TLS_PRIVATE_KEY, TLS_CREDENTIAL_PRIVATE_KEY, + aPrivateKey, strlen(aPrivateKey) + 1); + + if (status != 0) { + string errors; + tlsCredtErrMsg(errors, status); + throw TSSLException("SSL_CTX_use_PrivateKey: " + errors); + } + } else { + throw TSSLException("Unsupported certificate format: " + string(format)); + } +} + +void TSSLSocketFactory::loadTrustedCertificates(const char *path, const char *capath) +{ + if (path == nullptr) { + throw TTransportException(TTransportException::BAD_ARGS, + "loadTrustedCertificates: is nullptr"); + } + if (0) { + string errors; + // tlsCredtErrMsg(errors, status); + throw TSSLException("SSL_CTX_load_verify_locations: " + errors); + } +} + +void TSSLSocketFactory::loadTrustedCertificatesFromBuffer(const char *aCertificate, + const char *aChain) +{ + if (aCertificate == nullptr) { + throw TTransportException(TTransportException::BAD_ARGS, + "loadTrustedCertificates: aCertificate is empty"); + } + const int status = tls_credential_add(Thrift_TLS_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE, + aCertificate, strlen(aCertificate) + 1); + + if (status != 0) { + string errors; + tlsCredtErrMsg(errors, status); + throw TSSLException("X509_STORE_add_cert: " + errors); + } + + if (aChain) { + } +} + +void TSSLSocketFactory::randomize() +{ +} + +void TSSLSocketFactory::overrideDefaultPasswordCallback() +{ +} + +void TSSLSocketFactory::server(bool flag) +{ + server_ = flag; + ctx_->verifyMode = TLS_PEER_VERIFY_NONE; +} + +bool TSSLSocketFactory::server() const +{ + return server_; +} + +int TSSLSocketFactory::passwordCallback(char *password, int size, int, void *data) +{ + auto *factory = (TSSLSocketFactory *)data; + string userPassword; + factory->getPassword(userPassword, size); + int length = static_cast(userPassword.size()); + if (length > size) { + length = size; + } + strncpy(password, userPassword.c_str(), length); + userPassword.assign(userPassword.size(), '*'); + return length; +} + +// extract error messages from error queue +static void tlsCredtErrMsg(string &errors, const int status) +{ + if (status == EACCES) { + errors = "Access to the TLS credential subsystem was denied"; + } else if (status == ENOMEM) { + errors = "Not enough memory to add new TLS credential"; + } else if (status == EEXIST) { + errors = "TLS credential of specific tag and type already exists"; + } else { + errors = "Unknown error"; + } +} + +/** + * Default implementation of AccessManager + */ +Decision DefaultClientAccessManager::verify(const sockaddr_storage &sa) noexcept +{ + (void)sa; + return SKIP; +} + +Decision DefaultClientAccessManager::verify(const string &host, const char *name, int size) noexcept +{ + if (host.empty() || name == nullptr || size <= 0) { + return SKIP; + } + return (matchName(host.c_str(), name, size) ? ALLOW : SKIP); +} + +Decision DefaultClientAccessManager::verify(const sockaddr_storage &sa, const char *data, + int size) noexcept +{ + bool match = false; + if (sa.ss_family == AF_INET && size == sizeof(in_addr)) { + match = (memcmp(&((sockaddr_in *)&sa)->sin_addr, data, size) == 0); + } else if (sa.ss_family == AF_INET6 && size == sizeof(in6_addr)) { + match = (memcmp(&((sockaddr_in6 *)&sa)->sin6_addr, data, size) == 0); + } + return (match ? ALLOW : SKIP); +} + +/** + * Match a name with a pattern. The pattern may include wildcard. A single + * wildcard "*" can match up to one component in the domain name. + * + * @param host Host name, typically the name of the remote host + * @param pattern Name retrieved from certificate + * @param size Size of "pattern" + * @return True, if "host" matches "pattern". False otherwise. + */ +bool matchName(const char *host, const char *pattern, int size) +{ + bool match = false; + int i = 0, j = 0; + while (i < size && host[j] != '\0') { + if (uppercase(pattern[i]) == uppercase(host[j])) { + i++; + j++; + continue; + } + if (pattern[i] == '*') { + while (host[j] != '.' && host[j] != '\0') { + j++; + } + i++; + continue; + } + break; + } + if (i == size && host[j] == '\0') { + match = true; + } + return match; +} + +// This is to work around the Turkish locale issue, i.e., +// toupper('i') != toupper('I') if locale is "tr_TR" +char uppercase(char c) +{ + if ('a' <= c && c <= 'z') { + return c + ('A' - 'a'); + } + return c; +} +} // namespace transport +} // namespace thrift +} // namespace apache diff --git a/modules/thrift/src/thrift/transport/TSSLSocket.h b/modules/thrift/src/thrift/transport/TSSLSocket.h new file mode 100644 index 000000000000..db6fa359183c --- /dev/null +++ b/modules/thrift/src/thrift/transport/TSSLSocket.h @@ -0,0 +1,465 @@ +/* + * Copyright 2006 Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_ +#define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1 + +// Put this first to avoid WIN32 build failure +#include + +#include +#include + +#include + +namespace apache +{ +namespace thrift +{ +namespace transport +{ + +class AccessManager; +class SSLContext; + +enum SSLProtocol { + SSLTLS = 0, // Supports SSLv2 and SSLv3 handshake but only negotiates at TLSv1_0 or later. + // SSLv2 = 1, // HORRIBLY INSECURE! + SSLv3 = 2, // Supports SSLv3 only - also horribly insecure! + TLSv1_0 = 3, // Supports TLSv1_0 or later. + TLSv1_1 = 4, // Supports TLSv1_1 or later. + TLSv1_2 = 5, // Supports TLSv1_2 or later. + LATEST = TLSv1_2 +}; + +#define TSSL_EINTR 0 +#define TSSL_DATA 1 + +/** + * Initialize OpenSSL library. This function, or some other + * equivalent function to initialize OpenSSL, must be called before + * TSSLSocket is used. If you set TSSLSocketFactory to use manual + * OpenSSL initialization, you should call this function or otherwise + * ensure OpenSSL is initialized yourself. + */ +void initializeOpenSSL(); +/** + * Cleanup OpenSSL library. This function should be called to clean + * up OpenSSL after use of OpenSSL functionality is finished. If you + * set TSSLSocketFactory to use manual OpenSSL initialization, you + * should call this function yourself or ensure that whatever + * initialized OpenSSL cleans it up too. + */ +void cleanupOpenSSL(); + +/** + * OpenSSL implementation for SSL socket interface. + */ +class TSSLSocket : public TSocket +{ +public: + ~TSSLSocket() override; + /** + * TTransport interface. + */ + void open() override; + /** + * Set whether to use client or server side SSL handshake protocol. + * + * @param flag Use server side handshake protocol if true. + */ + void server(bool flag) + { + server_ = flag; + } + /** + * Determine whether the SSL socket is server or client mode. + */ + bool server() const + { + return server_; + } + /** + * Set AccessManager. + * + * @param manager Instance of AccessManager + */ + virtual void access(std::shared_ptr manager) + { + access_ = manager; + } + /** + * Set eventSafe flag if libevent is used. + */ + void setLibeventSafe() + { + eventSafe_ = true; + } + /** + * Determines whether SSL Socket is libevent safe or not. + */ + bool isLibeventSafe() const + { + return eventSafe_; + } + + void authenticate(bool required); + +protected: + /** + * Constructor. + */ + TSSLSocket(std::shared_ptr ctx, + std::shared_ptr config = nullptr); + /** + * Constructor with an interrupt signal. + */ + TSSLSocket(std::shared_ptr ctx, + std::shared_ptr interruptListener, + std::shared_ptr config = nullptr); + /** + * Constructor, create an instance of TSSLSocket given an existing socket. + * + * @param socket An existing socket + */ + TSSLSocket(std::shared_ptr ctx, THRIFT_SOCKET socket, + std::shared_ptr config = nullptr); + /** + * Constructor, create an instance of TSSLSocket given an existing socket that can be + * interrupted. + * + * @param socket An existing socket + */ + TSSLSocket(std::shared_ptr ctx, THRIFT_SOCKET socket, + std::shared_ptr interruptListener, + std::shared_ptr config = nullptr); + /** + * Constructor. + * + * @param host Remote host name + * @param port Remote port number + */ + TSSLSocket(std::shared_ptr ctx, std::string host, int port, + std::shared_ptr config = nullptr); + /** + * Constructor with an interrupt signal. + * + * @param host Remote host name + * @param port Remote port number + */ + TSSLSocket(std::shared_ptr ctx, std::string host, int port, + std::shared_ptr interruptListener, + std::shared_ptr config = nullptr); + /** + * Authorize peer access after SSL handshake completes. + */ + virtual void authorize(); + /** + * Initiate SSL handshake if not already initiated. + */ + void initializeHandshake(); + /** + * Initiate SSL handshake params. + */ + void initializeHandshakeParams(); + /** + * Check if SSL handshake is completed or not. + */ + bool checkHandshake(); + /** + * Waits for an socket or shutdown event. + * + * @throw TTransportException::INTERRUPTED if interrupted is signaled. + * + * @return TSSL_EINTR if EINTR happened on the underlying socket + * TSSL_DATA if data is available on the socket. + */ + unsigned int waitForEvent(bool wantRead); + + void openSecConnection(struct addrinfo *res); + + bool server_; + std::shared_ptr ctx_; + std::shared_ptr access_; + friend class TSSLSocketFactory; + +private: + bool handshakeCompleted_; + int readRetryCount_; + bool eventSafe_; + + void init(); +}; + +/** + * SSL socket factory. SSL sockets should be created via SSL factory. + * The factory will automatically initialize and cleanup openssl as long as + * there is a TSSLSocketFactory instantiated, and as long as the static + * boolean manualOpenSSLInitialization_ is set to false, the default. + * + * If you would like to initialize and cleanup openssl yourself, set + * manualOpenSSLInitialization_ to true and TSSLSocketFactory will no + * longer be responsible for openssl initialization and teardown. + * + * It is the responsibility of the code using TSSLSocketFactory to + * ensure that the factory lifetime exceeds the lifetime of any sockets + * it might create. If this is not guaranteed, a socket may call into + * openssl after the socket factory has cleaned up openssl! This + * guarantee is unnecessary if manualOpenSSLInitialization_ is true, + * however, since it would be up to the consuming application instead. + */ +class TSSLSocketFactory +{ + public: + /** + * Constructor/Destructor + * + * @param protocol The SSL/TLS protocol to use. + */ + TSSLSocketFactory(SSLProtocol protocol = SSLTLS); + virtual ~TSSLSocketFactory(); + /** + * Create an instance of TSSLSocket with a fresh new socket. + */ + virtual std::shared_ptr createSocket(); + /** + * Create an instance of TSSLSocket with a fresh new socket, which is interruptable. + */ + virtual std::shared_ptr + createSocket(std::shared_ptr interruptListener); + /** + * Create an instance of TSSLSocket with the given socket. + * + * @param socket An existing socket. + */ + virtual std::shared_ptr createSocket(THRIFT_SOCKET socket); + /** + * Create an instance of TSSLSocket with the given socket which is interruptable. + * + * @param socket An existing socket. + */ + virtual std::shared_ptr + createSocket(THRIFT_SOCKET socket, std::shared_ptr interruptListener); + /** + * Create an instance of TSSLSocket. + * + * @param host Remote host to be connected to + * @param port Remote port to be connected to + */ + virtual std::shared_ptr createSocket(const std::string &host, int port); + /** + * Create an instance of TSSLSocket. + * + * @param host Remote host to be connected to + * @param port Remote port to be connected to + */ + virtual std::shared_ptr + createSocket(const std::string &host, int port, + std::shared_ptr interruptListener); + /** + * Set ciphers to be used in SSL handshake process. + * + * @param ciphers A list of ciphers + */ + virtual void ciphers(const std::string &enable); + /** + * Enable/Disable authentication. + * + * @param required Require peer to present valid certificate if true + */ + virtual void authenticate(bool required); + /** + * Load server certificate. + * + * @param path Path to the certificate file + * @param format Certificate file format + */ + virtual void loadCertificate(const char *path, const char *format = "PEM"); + virtual void loadCertificateFromBuffer(const char *aCertificate, + const char *format = "PEM"); + /** + * Load private key. + * + * @param path Path to the private key file + * @param format Private key file format + */ + virtual void loadPrivateKey(const char *path, const char *format = "PEM"); + virtual void loadPrivateKeyFromBuffer(const char *aPrivateKey, const char *format = "PEM"); + /** + * Load trusted certificates from specified file. + * + * @param path Path to trusted certificate file + */ + virtual void loadTrustedCertificates(const char *path, const char *capath = nullptr); + virtual void loadTrustedCertificatesFromBuffer(const char *aCertificate, + const char *aChain = nullptr); + /** + * Default randomize method. + */ + virtual void randomize(); + /** + * Override default OpenSSL password callback with getPassword(). + */ + void overrideDefaultPasswordCallback(); + /** + * Set/Unset server mode. + * + * @param flag Server mode if true + */ + virtual void server(bool flag); + /** + * Determine whether the socket is in server or client mode. + * + * @return true, if server mode, or, false, if client mode + */ + virtual bool server() const; + /** + * Set AccessManager. + * + * @param manager The AccessManager instance + */ + virtual void access(std::shared_ptr manager) + { + access_ = manager; + } + static void setManualOpenSSLInitialization(bool manualOpenSSLInitialization) + { + manualOpenSSLInitialization_ = manualOpenSSLInitialization; + } + + protected: + std::shared_ptr ctx_; + + /** + * Override this method for custom password callback. It may be called + * multiple times at any time during a session as necessary. + * + * @param password Pass collected password to OpenSSL + * @param size Maximum length of password including NULL character + */ + virtual void getPassword(std::string & /* password */, int /* size */) + { + } + + private: + bool server_; + std::shared_ptr access_; + static concurrency::Mutex mutex_; + static uint64_t count_; + THRIFT_EXPORT static bool manualOpenSSLInitialization_; + + void setup(std::shared_ptr ssl); + static int passwordCallback(char *password, int size, int, void *data); +}; + +/** + * SSL exception. + */ +class TSSLException : public TTransportException +{ + public: + TSSLException(const std::string &message) + : TTransportException(TTransportException::INTERNAL_ERROR, message) + { + } + + const char *what() const noexcept override + { + if (message_.empty()) { + return "TSSLException"; + } else { + return message_.c_str(); + } + } +}; + +struct SSLContext { + int verifyMode = TLS_PEER_VERIFY_REQUIRED; + net_ip_protocol_secure protocol = IPPROTO_TLS_1_0; +}; + +/** + * Callback interface for access control. It's meant to verify the remote host. + * It's constructed when application starts and set to TSSLSocketFactory + * instance. It's passed onto all TSSLSocket instances created by this factory + * object. + */ +class AccessManager +{ + public: + enum Decision { + DENY = -1, // deny access + SKIP = 0, // cannot make decision, move on to next (if any) + ALLOW = 1 // allow access + }; + /** + * Destructor + */ + virtual ~AccessManager() = default; + /** + * Determine whether the peer should be granted access or not. It's called + * once after the SSL handshake completes successfully, before peer certificate + * is examined. + * + * If a valid decision (ALLOW or DENY) is returned, the peer certificate is + * not to be verified. + * + * @param sa Peer IP address + * @return True if the peer is trusted, false otherwise + */ + virtual Decision verify(const sockaddr_storage & /* sa */) noexcept + { + return DENY; + } + /** + * Determine whether the peer should be granted access or not. It's called + * every time a DNS subjectAltName/common name is extracted from peer's + * certificate. + * + * @param host Client mode: host name returned by TSocket::getHost() + * Server mode: host name returned by TSocket::getPeerHost() + * @param name SubjectAltName or common name extracted from peer certificate + * @param size Length of name + * @return True if the peer is trusted, false otherwise + * + * Note: The "name" parameter may be UTF8 encoded. + */ + virtual Decision verify(const std::string & /* host */, const char * /* name */, + int /* size */) noexcept + { + return DENY; + } + /** + * Determine whether the peer should be granted access or not. It's called + * every time an IP subjectAltName is extracted from peer's certificate. + * + * @param sa Peer IP address retrieved from the underlying socket + * @param data IP address extracted from certificate + * @param size Length of the IP address + * @return True if the peer is trusted, false otherwise + */ + virtual Decision verify(const sockaddr_storage & /* sa */, const char * /* data */, + int /* size */) noexcept + { + return DENY; + } +}; + +typedef AccessManager::Decision Decision; + +class DefaultClientAccessManager : public AccessManager +{ + public: + // AccessManager interface + Decision verify(const sockaddr_storage &sa) noexcept override; + Decision verify(const std::string &host, const char *name, int size) noexcept override; + Decision verify(const sockaddr_storage &sa, const char *data, int size) noexcept override; +}; +} // namespace transport +} // namespace thrift +} // namespace apache + +#endif diff --git a/modules/thrift/src/thrift/transport/TServerSocket.h b/modules/thrift/src/thrift/transport/TServerSocket.h new file mode 100644 index 000000000000..f213bd7bcfa7 --- /dev/null +++ b/modules/thrift/src/thrift/transport/TServerSocket.h @@ -0,0 +1,186 @@ +/* + * Copyright 2006 Facebook + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_ +#define _THRIFT_TRANSPORT_TSERVERSOCKET_H_ 1 + +#include + +#include +#include +#include + +#include +#ifdef HAVE_SYS_SOCKET_H +#include +#endif +#ifdef HAVE_NETDB_H +#include +#endif + +namespace apache +{ +namespace thrift +{ +namespace transport +{ + +class TSocket; + +/** + * Server socket implementation of TServerTransport. Wrapper around a unix + * socket listen and accept calls. + * + */ +class TServerSocket : public TServerTransport +{ +public: + typedef std::function socket_func_t; + + const static int DEFAULT_BACKLOG = 1024; + + /** + * Constructor. + * + * @param port Port number to bind to + */ + TServerSocket(int port); + + /** + * Constructor. + * + * @param port Port number to bind to + * @param sendTimeout Socket send timeout + * @param recvTimeout Socket receive timeout + */ + TServerSocket(int port, int sendTimeout, int recvTimeout); + + /** + * Constructor. + * + * @param address Address to bind to + * @param port Port number to bind to + */ + TServerSocket(const std::string &address, int port); + + /** + * Constructor used for unix sockets. + * + * @param path Pathname for unix socket. + */ + TServerSocket(const std::string &path); + + ~TServerSocket() override; + + bool isOpen() const override; + + void setSendTimeout(int sendTimeout); + void setRecvTimeout(int recvTimeout); + + void setAcceptTimeout(int accTimeout); + void setAcceptBacklog(int accBacklog); + + void setRetryLimit(int retryLimit); + void setRetryDelay(int retryDelay); + + void setKeepAlive(bool keepAlive) + { + keepAlive_ = keepAlive; + } + + void setTcpSendBuffer(int tcpSendBuffer); + void setTcpRecvBuffer(int tcpRecvBuffer); + + // listenCallback gets called just before listen, and after all Thrift + // setsockopt calls have been made. If you have custom setsockopt + // things that need to happen on the listening socket, this is the place to do it. + void setListenCallback(const socket_func_t &listenCallback) + { + listenCallback_ = listenCallback; + } + + // acceptCallback gets called after each accept call, on the newly created socket. + // It is called after all Thrift setsockopt calls have been made. If you have + // custom setsockopt things that need to happen on the accepted + // socket, this is the place to do it. + void setAcceptCallback(const socket_func_t &acceptCallback) + { + acceptCallback_ = acceptCallback; + } + + // When enabled (the default), new children TSockets will be constructed so + // they can be interrupted by TServerTransport::interruptChildren(). + // This is more expensive in terms of system calls (poll + recv) however + // ensures a connected client cannot interfere with TServer::stop(). + // + // When disabled, TSocket children do not incur an additional poll() call. + // Server-side reads are more efficient, however a client can interfere with + // the server's ability to shutdown properly by staying connected. + // + // Must be called before listen(); mode cannot be switched after that. + // \throws std::logic_error if listen() has been called + void setInterruptableChildren(bool enable); + + THRIFT_SOCKET getSocketFD() override + { + return serverSocket_; + } + + int getPort() const; + + std::string getPath() const; + + bool isUnixDomainSocket() const; + + void listen() override; + void interrupt() override; + void interruptChildren() override; + void close() override; + + protected: + std::shared_ptr acceptImpl() override; + virtual std::shared_ptr createSocket(THRIFT_SOCKET client); + bool interruptableChildren_; + std::shared_ptr pChildInterruptSockReader_; // if interruptableChildren_ this + // is shared with child TSockets + + void _setup_sockopts(); + void _setup_tcp_sockopts(); + +private: + void notify(THRIFT_SOCKET notifySock); + void _setup_unixdomain_sockopts(); + +protected: + int port_; + std::string address_; + std::string path_; + THRIFT_SOCKET serverSocket_; + int acceptBacklog_; + int sendTimeout_; + int recvTimeout_; + int accTimeout_; + int retryLimit_; + int retryDelay_; + int tcpSendBuffer_; + int tcpRecvBuffer_; + bool keepAlive_; + bool listening_; + + concurrency::Mutex rwMutex_; // thread-safe interrupt + THRIFT_SOCKET interruptSockWriter_; // is notified on interrupt() + THRIFT_SOCKET + interruptSockReader_; // is used in select/poll with serverSocket_ for interruptability + THRIFT_SOCKET childInterruptSockWriter_; // is notified on interruptChildren() + + socket_func_t listenCallback_; + socket_func_t acceptCallback_; +}; +} // namespace transport +} // namespace thrift +} // namespace apache + +#endif // #ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_ diff --git a/modules/thrift/src/thrift/transport/ThriftTLScertificateType.h b/modules/thrift/src/thrift/transport/ThriftTLScertificateType.h new file mode 100644 index 000000000000..7558ae91cc95 --- /dev/null +++ b/modules/thrift/src/thrift/transport/ThriftTLScertificateType.h @@ -0,0 +1,21 @@ +/* + * Copyright 2022 Young Mei + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef ZEPHYR_MODULES_THRIFT_SRC_THRIFT_TRANSPORT_THRIFTTLSCERTIFICATETYPE_H_ +#define ZEPHYR_MODULES_THRIFT_SRC_THRIFT_TRANSPORT_THRIFTTLSCERTIFICATETYPE_H_ + +namespace apache::thrift::transport +{ + +enum ThriftTLScertificateType { + Thrift_TLS_CA_CERT_TAG, + Thrift_TLS_SERVER_CERT_TAG, + Thrift_TLS_PRIVATE_KEY, +}; + +} // namespace apache::thrift::transport + +#endif /* ZEPHYR_MODULES_THRIFT_SRC_THRIFT_TRANSPORT_THRIFTTLSCERTIFICATETYPE_H_ */ diff --git a/samples/modules/thrift/hello/client/CMakeLists.txt b/samples/modules/thrift/hello/client/CMakeLists.txt new file mode 100644 index 000000000000..44860f78da4b --- /dev/null +++ b/samples/modules/thrift/hello/client/CMakeLists.txt @@ -0,0 +1,61 @@ +# Copyright 2022 Meta +# SPDX-License-Identifier: Apache-2.0 + +cmake_minimum_required(VERSION 3.20.0) +find_package(Zephyr REQUIRED HINTS $ENV{ZEPHYR_BASE}) +project(thrift_hello_server) + +FILE(GLOB app_sources + src/*.cpp +) + +include(${ZEPHYR_BASE}/modules/thrift/cmake/thrift.cmake) + +set(generated_sources "") +set(gen_dir ${ZEPHYR_BINARY_DIR}/misc/generated/thrift_hello) +list(APPEND generated_sources ${gen_dir}/gen-cpp/hello_types.h) +list(APPEND generated_sources ${gen_dir}/gen-cpp/Hello.cpp) +list(APPEND generated_sources ${gen_dir}/gen-cpp/Hello.h) +list(APPEND app_sources ${generated_sources}) + +thrift( + app + cpp + :no_skeleton + ${gen_dir} + ${ZEPHYR_BASE}/samples/modules/thrift/hello/hello.thrift + "" + ${generated_sources} +) + +target_sources(app PRIVATE ${app_sources}) + +# needed because std::iterator was deprecated with -std=c++17 +target_compile_options(app PRIVATE -Wno-deprecated-declarations) + +# convert .pem files to array data at build time +zephyr_include_directories(${gen_dir}) + +generate_inc_file_for_target( + app + ${ZEPHYR_BASE}/samples/modules/thrift/hello/qemu-cert.pem + ${gen_dir}/qemu_cert.pem.inc + ) + +generate_inc_file_for_target( + app + ${ZEPHYR_BASE}/samples/modules/thrift/hello/qemu-key.pem + ${gen_dir}/qemu_key.pem.inc + ) + +generate_inc_file_for_target( + app + ${ZEPHYR_BASE}/samples/modules/thrift/hello/native-cert.pem + ${gen_dir}/native_cert.pem.inc + ) + +generate_inc_file_for_target( + app + ${ZEPHYR_BASE}/samples/modules/thrift/hello/native-key.pem + ${gen_dir}/native_key.pem.inc + ) diff --git a/samples/modules/thrift/hello/client/Kconfig b/samples/modules/thrift/hello/client/Kconfig new file mode 100644 index 000000000000..8a15c50347e2 --- /dev/null +++ b/samples/modules/thrift/hello/client/Kconfig @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2022 Meta + +source "Kconfig.zephyr" + +config THRIFT_COMPACT_PROTOCOL + bool "Use TCompactProtocol in samples" + depends on THRIFT + help + Enable this option to use TCompactProtocol in samples diff --git a/samples/modules/thrift/hello/client/Makefile b/samples/modules/thrift/hello/client/Makefile new file mode 100644 index 000000000000..dabdf64c0ca7 --- /dev/null +++ b/samples/modules/thrift/hello/client/Makefile @@ -0,0 +1,44 @@ +# Copyright 2022 Meta +# SPDX-License-Identifier: Apache-2.0 + +.PHONY: all clean + +CXXFLAGS := +CXXFLAGS += -std=c++17 + +GEN_DIR = gen-cpp +GENSRC = $(GEN_DIR)/Hello.cpp $(GEN_DIR)/Hello.h $(GEN_DIR)/hello_types.h +GENHDR = $(filter %.h, $(GENSRC)) +GENOBJ = $(filter-out %.h, $(GENSRC:.cpp=.o)) + +THRIFT_FLAGS := +THRIFT_FLAGS += $(shell pkg-config --cflags thrift) +THRIFT_FLAGS += -I$(GEN_DIR) +THRIFT_LIBS = $(shell pkg-config --libs thrift) + +all: hello_client hello_client_compact hello_client_ssl hello_client_py.stamp + +hello_client.stamp: ../hello.thrift + thrift --gen cpp:no_skeleton $< + +$(GENSRC): hello_client.stamp + touch $@ + +%.o: %.cpp $(GENHDR) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(THRIFT_FLAGS) -o $@ -c $< + +hello_client: src/main.cpp $(GENOBJ) $(GENHDR) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(THRIFT_FLAGS) -o $@ $< $(GENOBJ) $(THRIFT_LIBS) + +hello_client_compact: src/main.cpp $(GENOBJ) $(GENHDR) + $(CXX) -DCONFIG_THRIFT_COMPACT_PROTOCOL=1 $(CPPFLAGS) $(CXXFLAGS) $(THRIFT_FLAGS) -o $@ $< $(GENOBJ) $(THRIFT_LIBS) + +hello_client_ssl: src/main.cpp $(GENOBJ) $(GENHDR) + $(CXX) -DCONFIG_THRIFT_SSL_SOCKET=1 $(CPPFLAGS) $(CXXFLAGS) $(THRIFT_FLAGS) -o $@ $< $(GENOBJ) $(THRIFT_LIBS) + +hello_client_py.stamp: ../hello.thrift + thrift --gen py $< + touch $@ + +clean: + rm -Rf hello_client hello_client_compact hello_client_ssl $(GEN_DIR) gen-py *.stamp diff --git a/samples/modules/thrift/hello/client/hello_client.py b/samples/modules/thrift/hello/client/hello_client.py new file mode 100755 index 000000000000..06422347ff03 --- /dev/null +++ b/samples/modules/thrift/hello/client/hello_client.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023, Meta +# +# SPDX-License-Identifier: Apache-2.0 + +"""Thrift Hello Client Sample + +Connect to a hello service and demonstrate the +ping(), echo(), and counter() Thrift RPC methods. + +Usage: + ./hello_client.py [ip] +""" + +import argparse +import sys +sys.path.append('gen-py') + +from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport +from thrift.transport import TSocket +from hello import Hello + + +def parse_args(): + parser = argparse.ArgumentParser(allow_abbrev=False) + parser.add_argument('--ip', default='192.0.2.1', + help='IP address of hello server') + + return parser.parse_args() + + +def main(): + args = parse_args() + + transport = TSocket.TSocket(args.ip, 4242) + transport = TTransport.TBufferedTransport(transport) + protocol = TBinaryProtocol.TBinaryProtocol(transport) + client = Hello.Client(protocol) + + transport.open() + + client.ping() + client.echo('Hello, world!') + + # necessary to mitigate unused variable warning with for i in range(5) + i = 0 + while i < 5: + client.counter() + i = i + 1 + + transport.close() + + +if __name__ == '__main__': + main() diff --git a/samples/modules/thrift/hello/client/prj.conf b/samples/modules/thrift/hello/client/prj.conf new file mode 100644 index 000000000000..83026931ad21 --- /dev/null +++ b/samples/modules/thrift/hello/client/prj.conf @@ -0,0 +1,72 @@ +# CONFIG_LIB_CPLUSPLUS Dependencies +CONFIG_NEWLIB_LIBC=y +CONFIG_NEWLIB_LIBC_NANO=n + +# CONFIG_THRIFT Dependencies +CONFIG_CPP=y +CONFIG_STD_CPP17=y +CONFIG_CPP_EXCEPTIONS=y +CONFIG_EXTERNAL_LIBCPP=y +CONFIG_POSIX_API=y +CONFIG_NETWORKING=y +CONFIG_NET_SOCKETS=y +CONFIG_NET_SOCKETPAIR=y +CONFIG_HEAP_MEM_POOL_SIZE=16384 +CONFIG_EVENTFD=y + +CONFIG_THRIFT=y + +CONFIG_TEST_RANDOM_GENERATOR=y +# pthread_cond_wait() triggers sentinel for some reason +CONFIG_STACK_SENTINEL=n + +# Generic networking options +CONFIG_NETWORKING=y +CONFIG_NET_UDP=y +CONFIG_NET_TCP=y +CONFIG_NET_IPV6=n +CONFIG_NET_IPV4=y +CONFIG_NET_SOCKETS=y +CONFIG_POSIX_MAX_FDS=6 +CONFIG_NET_CONNECTION_MANAGER=y + +# Kernel options +CONFIG_ENTROPY_GENERATOR=y +CONFIG_TEST_RANDOM_GENERATOR=y +CONFIG_INIT_STACKS=y + +# Logging +CONFIG_NET_LOG=y +CONFIG_LOG=y +CONFIG_NET_STATISTICS=y +CONFIG_PRINTK=y + +# Network buffers +CONFIG_NET_PKT_RX_COUNT=16 +CONFIG_NET_PKT_TX_COUNT=16 +CONFIG_NET_BUF_RX_COUNT=64 +CONFIG_NET_BUF_TX_COUNT=64 +CONFIG_NET_CONTEXT_NET_PKT_POOL=y + +# IP address options +CONFIG_NET_MAX_CONTEXTS=10 + +# Network application options and configuration +CONFIG_NET_CONFIG_SETTINGS=y +CONFIG_NET_CONFIG_NEED_IPV6=n +CONFIG_NET_CONFIG_NEED_IPV4=y +CONFIG_NET_CONFIG_MY_IPV4_ADDR="192.0.2.1" +CONFIG_NET_CONFIG_PEER_IPV4_ADDR="192.0.2.2" + +# Number of socket descriptors might need adjusting +# if there are more than 1 handlers defined. +CONFIG_POSIX_MAX_FDS=16 + +# Some platforms require relatively large stack sizes. +# This can be tuned per-board. +CONFIG_MAIN_STACK_SIZE=8192 +CONFIG_SYSTEM_WORKQUEUE_STACK_SIZE=8192 +CONFIG_NET_TCP_WORKQ_STACK_SIZE=4096 +CONFIG_NET_MGMT_EVENT_STACK_SIZE=4096 +CONFIG_IDLE_STACK_SIZE=4096 +CONFIG_NET_RX_STACK_SIZE=8192 diff --git a/samples/modules/thrift/hello/client/sample.yaml b/samples/modules/thrift/hello/client/sample.yaml new file mode 100644 index 000000000000..abe77def0179 --- /dev/null +++ b/samples/modules/thrift/hello/client/sample.yaml @@ -0,0 +1,16 @@ +sample: + description: Hello Thrift client sample + name: hello thrift client +common: + tags: thrift cpp sample + build_only: true + modules: + - thrift + platform_allow: mps2_an385 qemu_x86_64 +tests: + sample.thrift.hello.server.binaryProtocol: {} + sample.thrift.hello.server.compactProtocol: + extra_configs: + - CONFIG_THRIFT_COMPACT_PROTOCOL=y + sample.thrift.hello.server.tlsTransport: + extra_args: OVERLAY_CONFIG="../overlay-tls.conf" diff --git a/samples/modules/thrift/hello/client/src/main.cpp b/samples/modules/thrift/hello/client/src/main.cpp new file mode 100644 index 000000000000..bc79c6fbf403 --- /dev/null +++ b/samples/modules/thrift/hello/client/src/main.cpp @@ -0,0 +1,130 @@ +/* + * Copyright 2022 Young Mei + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef __ZEPHYR__ +#include +#endif + +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "Hello.h" + +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; + +#ifndef IS_ENABLED +#define IS_ENABLED(flag) flag +#endif + +#ifndef CONFIG_THRIFT_COMPACT_PROTOCOL +#define CONFIG_THRIFT_COMPACT_PROTOCOL 0 +#endif + +#ifndef CONFIG_THRIFT_SSL_SOCKET +#define CONFIG_THRIFT_SSL_SOCKET 0 +#endif + +#ifdef __ZEPHYR__ +int main(void) +#else +int main(int argc, char **argv) +#endif +{ + std::string my_addr; + +#ifdef __ZEPHYR__ + my_addr = CONFIG_NET_CONFIG_PEER_IPV4_ADDR; +#else + if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) { + if (argc != 5) { + printf("usage: %s " + "\n", + argv[0]); + return EXIT_FAILURE; + } + } + + if (argc >= 2) { + my_addr = std::string(argv[1]); + } else { + my_addr = "192.0.2.1"; + } +#endif + + int port = 4242; + std::shared_ptr protocol; + std::shared_ptr transport; + std::shared_ptr socketFactory; + std::shared_ptr trans; + + if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) { + const int port = 4242; + socketFactory = std::make_shared(); + socketFactory->authenticate(true); + +#ifdef __ZEPHYR__ + static const char qemu_cert_pem[] = { +#include "qemu_cert.pem.inc" + }; + + static const char qemu_key_pem[] = { +#include "qemu_key.pem.inc" + }; + + static const char native_cert_pem[] = { +#include "native_cert.pem.inc" + }; + + socketFactory->loadCertificateFromBuffer(qemu_cert_pem); + socketFactory->loadPrivateKeyFromBuffer(qemu_key_pem); + socketFactory->loadTrustedCertificatesFromBuffer(native_cert_pem); +#else + socketFactory->loadCertificate(argv[2]); + socketFactory->loadPrivateKey(argv[3]); + socketFactory->loadTrustedCertificates(argv[4]); +#endif + trans = socketFactory->createSocket(my_addr, port); + } else { + trans = std::make_shared(my_addr, port); + } + + transport = std::make_shared(trans); + + if (IS_ENABLED(CONFIG_THRIFT_COMPACT_PROTOCOL)) { + protocol = std::make_shared(transport); + } else { + protocol = std::make_shared(transport); + } + + HelloClient client(protocol); + + try { + transport->open(); + client.ping(); + std::string s; + client.echo(s, "Hello, world!"); + for (int i = 0; i < 5; ++i) { + client.counter(); + } + + transport->close(); + } catch (std::exception &e) { + printf("caught exception: %s\n", e.what()); + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} diff --git a/samples/modules/thrift/hello/doc/index.rst b/samples/modules/thrift/hello/doc/index.rst new file mode 100644 index 000000000000..84a784e02189 --- /dev/null +++ b/samples/modules/thrift/hello/doc/index.rst @@ -0,0 +1,165 @@ +.. _thrift-hello-sample: + +Hello Sample Application +######################## + +Overview +******** + +This sample application includes a client a server implementing the RPC +interface described in ``thrift/hello.thrift``. The purpose of this +example is to demonstrate how components at different layer in thrift can +be combined to build an application with desired features. + + +Requirements +************ + +- QEMU Networking (described in :ref:`networking_with_qemu`) +- Thrift dependencies installed for your host OS e.g. in Ubuntu + +.. code-block:: console + :caption: Install additional dependencies in Ubuntu + + sudo apt install -y libboost-all-dev thrift-compiler libthrift-dev + +Building and Running +******************** + +This application can be run on a Linux host, with either the server or the +client in the QEMU environment, and the peer is built and run natively on +the host. + +Building the Native Client and Server +===================================== + +.. code-block:: console + + $ make -j -C client/ + $ make -j -C server/ + +Under ``client/``, 3 executables will be generated, and components +used in each layer of them are listed below: + ++----------------------+------------+--------------------+------------------+ +| hello_client | TSocket | TBufferedTransport | TBinaryProtocol | ++----------------------+------------+--------------------+------------------+ +| hello_client_compact | TSocket | TBufferedTransport | TCompactProtocol | ++----------------------+------------+--------------------+------------------+ +| hello_client_ssl | TSSLSocket | TBufferedTransport | TBinaryProtocol | ++----------------------+------------+--------------------+------------------+ + +The same applies for the server. Only the client and the server with the +same set of stacks can communicate. + +Additionally, there is a ``hello_client.py`` Python script that can be used +interchangeably with the ``hello_client`` C++ application to illustrate the +cross-language capabilities of Thrift. + ++----------------------+------------+--------------------+------------------+ +| hello_client.py | TSocket | TBufferedTransport | TBinaryProtocol | ++----------------------+------------+--------------------+------------------+ + +Running the Zephyr Server in QEMU +================================= + +Build the Zephyr version of the ``hello/server`` sample application like this: + +.. zephyr-app-commands:: + :zephyr-app: samples/modules/thrift/hello/server + :board: board_name + :goals: build + :compact: + +To enable advanced features, extra arguments should be passed accordingly: + +- TCompactProtocol: ``-DCONFIG_THRIFT_COMPACT_PROTOCOL=y`` +- TSSLSocket: ``-DCONF_FILE="prj.conf ../overlay-tls.conf"`` + +For example, to build for ``qemu_x86_64`` with TSSLSocket support: + +.. zephyr-app-commands:: + :zephyr-app: samples/modules/thrift/hello/server + :host-os: unix + :board: qemu_x86_64 + :conf: "prj.conf ../overlay-tls.conf" + :goals: run + :compact: + +In another terminal, run the ``hello_client`` sample app compiled for the +host OS: + +.. code-block:: console + + $ ./hello_client 192.0.2.1 + $ ./hello_client_compact 192.0.2.1 + $ ./hello_client_ssl 192.0.2.1 ../native-cert.pem ../native-key.pem ../qemu-cert.pem + +You should observe the following in the original ``hello/server`` terminal: + +.. code-block:: console + + ping + echo: Hello, world! + counter: 1 + counter: 2 + counter: 3 + counter: 4 + counter: 5 + +In the client terminal, run ``hello_client.py`` app under the host OS (not +described for compact or ssl variants for brevity): + +.. code-block:: console + + $ ./hello_client.py + +You should observe the following in the original ``hello/server`` terminal. +Note that the server's state is not discarded (the counter continues to +increase). + +.. code-block:: console + + ping + echo: Hello, world! + counter: 6 + counter: 7 + counter: 8 + counter: 9 + counter: 10 + +Running the Zephyr Client in QEMU +================================= + +In another terminal, run the ``hello_server`` sample app compiled for the +host OS: + +.. code-block:: console + + $ ./hello_server 0.0.0.0 + $ ./hello_server_compact 0.0.0.0 + $ ./hello_server_ssl 0.0.0.0 ../native-cert.pem ../native-key.pem ../qemu-cert.pem + + +Then, in annother terminal, run the corresponding ``hello/client`` sample: + +.. zephyr-app-commands:: + :zephyr-app: samples/modules/thrift/hello/client + :board: qemu_x86_64 + :goals: run + :compact: + +The additional arguments for advanced features are the same as +``hello/server``. + +You should observe the following in the original ``hello_server`` terminal: + +.. code-block:: console + + ping + echo: Hello, world! + counter: 1 + counter: 2 + counter: 3 + counter: 4 + counter: 5 diff --git a/samples/modules/thrift/hello/hello.thrift b/samples/modules/thrift/hello/hello.thrift new file mode 100644 index 000000000000..ba094466d107 --- /dev/null +++ b/samples/modules/thrift/hello/hello.thrift @@ -0,0 +1,11 @@ +/* + * Copyright 2022 Meta + * + * SPDX-License-Identifier: Apache-2.0 + */ + +service Hello { + void ping(); + string echo(1: string msg); + i32 counter(); +} diff --git a/samples/modules/thrift/hello/native-cert.pem b/samples/modules/thrift/hello/native-cert.pem new file mode 100644 index 000000000000..3adf82b166c0 --- /dev/null +++ b/samples/modules/thrift/hello/native-cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUaGlOEAH7p7FyprSNEHTH3Yy3+8owDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJMTkyLjAuMi4yMB4X +DTIyMDgwOTA3NDgzOVoXDTQ5MTIyNTA3NDgzOVowWTELMAkGA1UEBhMCQVUxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJMTkyLjAuMi4yMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA4cg88HTUwqufnmbCfsZ/j+lkuRLMRWVuCyZ6d4AapIwHanHhdAhj +jejPbEnSMacxiYUVhuKMjMWw1sUMKTwu9MA2wk6jCvewEhnCymaZY0IHE1SCY8/B +Zik70ds/0D8OtUB53xgR5el/ntUUJgmczW5ZIKmpcW86OR0rzUs9j6I8MF3Tp4qK +PSFBuWMC+nXkPuX0l631dtk76DZabOqST3Hiqi4dDV+TaOO5KmN2k6E/iw7p7X4F +ddt9JJIlErV9DZqaKvymdQXlhymohBQHPE909J9z0HEkVqBHRgiMvUyo2hYOqLya +qe+eWvJIEvF/bpEHcESm1v+txagk+BulyQIDAQABo1MwUTAdBgNVHQ4EFgQU7PKo +jr9k+KBC/hmnAUKqsfpgZSowHwYDVR0jBBgwFoAU7PKojr9k+KBC/hmnAUKqsfpg +ZSowDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAVwUrUMLZbrAi +VAS72zpvedGN912bkzgP3vsDbW/F4jbwq+NcVFoqtgUiJrjbhPa52qig6bzNtllX +fVIAjh/pCB+PYDDXotO3mk79Sofy3W2qNrlFQNJsqJpDkY/yF2Dg8FdKao6oPwSs +ldb6REFy6V7I9pdy4Emq+ObW6btzEBByky5TrwUf44ZwSuhxKB2R+jqZHM7BNsnH +QnGUia/qGF1plSUqFzsdq9AwQ6H9v2SLwPDOqEGLDV1Jvwoe0svp+SaaBxY1aGHD +Mg5Z6Uh5cHYlJLCC7WsfTlH+9HZ9ALg0Gww1twlYXyaMw4R480YtSPZC4Wv0jy35 +Sbw6IAbCTg== +-----END CERTIFICATE----- diff --git a/samples/modules/thrift/hello/native-key.pem b/samples/modules/thrift/hello/native-key.pem new file mode 100644 index 000000000000..b6c67e03b844 --- /dev/null +++ b/samples/modules/thrift/hello/native-key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDhyDzwdNTCq5+e +ZsJ+xn+P6WS5EsxFZW4LJnp3gBqkjAdqceF0CGON6M9sSdIxpzGJhRWG4oyMxbDW +xQwpPC70wDbCTqMK97ASGcLKZpljQgcTVIJjz8FmKTvR2z/QPw61QHnfGBHl6X+e +1RQmCZzNblkgqalxbzo5HSvNSz2PojwwXdOnioo9IUG5YwL6deQ+5fSXrfV22Tvo +Nlps6pJPceKqLh0NX5No47kqY3aToT+LDuntfgV1230kkiUStX0Nmpoq/KZ1BeWH +KaiEFAc8T3T0n3PQcSRWoEdGCIy9TKjaFg6ovJqp755a8kgS8X9ukQdwRKbW/63F +qCT4G6XJAgMBAAECggEAXbZQQNelJWW5mzP4n0kBUjijs0NvoJAgdCVU6Hu10z1B +qLc6xf/jXlfWnBIp2a0VHQitbi5i+tzk8MeZrBXMQY70S4L7Hka/AExL8tlR6gZS +TH4jnozxL1eG+iv/2Q4LK0TnMKdbamuXqlOziLQtroCSIsH4z9nEN0d50jxcAVyj +/WUFdWztYbrZFT2m16zQcic2/GyGGDTDkcmpu3+FGhbCDz22W/3iXS57Vhtd9Ety +9VnxA0cDC7O2X3GuQyCaCSmepC+hdZlJTiW16jto1IPYXm38cVn8WETcRpazmtfW +0wtMSFinUZpGpbEBPEvhIVE5Yaor+07qzGIuEwKrXwKBgQDx7b6bzxqk6TAt7oH/ +ub6X4u1tarfM9e/T8R4CfEo9sLvZjaiwKIclCueTjdRUJOogMoZCcZRdD0+NCg5h +JR4FmCwpSB1iaiRDYmPJr7hwlqmwah/t90DKdPwY9uQn//TlWoUwDWtBZ7TwB/HN +LTfyO2tTMUeNCp8VYV8z8Fe0dwKBgQDu6hN9opKpqyTm8Aqwiy+LbKCQg9YwI3zV +3B7KP1buaa4WHiWTWd86Ns6wwFRQSr3jAexNQaDvH8MhI696iBaEVLAvYFmGomLR +V3dpudap143AV0wwvniUn78ewyLPg7V5fUcetcX8QlKLp0wGWuYAZqgJ3uA2Kf3v +qgIi7BlHvwKBgCKNYfu+yH9lDoyA0/BCBwaKUn6eD0ImneoXNcIFHlVROIMJyF3g +a+zOceSRDRI3c3jFvoce0aG43hO2q/cT5gXGhggfVJMJtcQp+TaE8kKiQfoALi8+ +cPJ5Ysft+wf7dm6LTxpd0EO3HBBsEgzLuIHQGrP3BdEPA0l6bq5sVRphAoGBAJI8 +MGXsBn1XxiSctM5Ow3FBsh4CtC2O6zAzpZ0BnAIeKXJcTX+duOb2+RhzAKiMtyGl +4a+ABjOXa2ZzY0tK1Q12kMjO3r1r07RzJyJNn7khuSALzxTe4QuHpAH+SuZdpcyR +A+EmPeMj7UaRxhT1umZwb1ZrVy2QEmCJ3PjnLqodAoGBALVLa4X2CHHEACculEcC +wDUdaF8penGVV0HP+4DvxLUtcvPohvdyvcoTzqMqkk2QCKRfnImD3aL014Pgv26C +9kO1C8/K+cvwq2Pc6v2fgTHwSqj1DglNEJOctbStpwTDugG+vsymlTaD6xogFunW +SmtskQUnB2GXxqXMzmMIGUYC +-----END PRIVATE KEY----- diff --git a/samples/modules/thrift/hello/overlay-tls.conf b/samples/modules/thrift/hello/overlay-tls.conf new file mode 100644 index 000000000000..3524f07558ad --- /dev/null +++ b/samples/modules/thrift/hello/overlay-tls.conf @@ -0,0 +1,10 @@ +CONFIG_THRIFT_SSL_SOCKET=y + +# TLS configuration +CONFIG_MBEDTLS=y +CONFIG_MBEDTLS_PEM_CERTIFICATE_FORMAT=y +CONFIG_MBEDTLS_ENABLE_HEAP=y +CONFIG_MBEDTLS_HEAP_SIZE=60000 +CONFIG_MBEDTLS_SSL_MAX_CONTENT_LEN=2048 +CONFIG_NET_SOCKETS_TLS_MAX_CONTEXTS=6 +CONFIG_NET_SOCKETS_SOCKOPT_TLS=y diff --git a/samples/modules/thrift/hello/qemu-cert.pem b/samples/modules/thrift/hello/qemu-cert.pem new file mode 100644 index 000000000000..d89973ebb7a9 --- /dev/null +++ b/samples/modules/thrift/hello/qemu-cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUImw6a5K1bBeocIiUyvyhQLUWKBYwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJMTkyLjAuMi4xMB4X +DTIyMDgwOTA3NDgwMVoXDTQ5MTIyNTA3NDgwMVowWTELMAkGA1UEBhMCQVUxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJMTkyLjAuMi4xMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAyXtaqbldQrHL61NcMDhR3MTv8cVCFnsVZEgq621vZ5njCnK557LR +gOKJx6Mn+8d0au1RgjIjyhuW0aGRQnnkX8mBXSqFJ+jQYTBAs4i5Jemn+Rsf17Lj +R11eGNjqItS5JqlrKUfY3CfgFoN/YJITmDgIE0d6NbJbq+LkuBuvdf4+bxDp8w/m +gOr2wvGcAGvtaAWsJVaQvXkLlXKwAOR4/ChFLQJexe4KreMwnvaa/0JXRia+4I+U +PWSJUmqcpwebj8oDgXqiQv+6JXF66ZoULVINiiDk9qtAbksm8Vz/QQ8Ll4ML96fZ +ZXM3sie0OSsEdpzfLayYwc3sX7Y8p1f4SQIDAQABo1MwUTAdBgNVHQ4EFgQU5LRU +s0gUR8AEWND35ZbE3QilH5wwHwYDVR0jBBgwFoAU5LRUs0gUR8AEWND35ZbE3Qil +H5wwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAKQm2tMYeFUqb +ep1Teq+zcsjRf3Zt7CUQoYUKwRbxfG+M8f0KtE0XgJ5jntRHMfiw157qGJY2YRPY +DJ1G8TmF//x4+cudixuZlkD0t7J+YQBLjNtvu68FIb+gNuv/GJSBps1C4Q49P8/c +48ulUCvEUjUwJ99yC/v1BajdUdunqgfwVnD3i5cAm1dStLMxeuUMhU7OaMnylkMY +eF+CXzZ7k2Nmr9SlGZR28kiDJF6TdfkVTxL8A5Xyjh/DvRxVYYE0LzmG99Q1qj5S +Q1U4wCHJDm7/sb7GfXrl6abqECWPWi65/77QSsq8LkRkbbpkCvm9MyxKPSMqqUeF +vH7+Xv7oMw== +-----END CERTIFICATE----- diff --git a/samples/modules/thrift/hello/qemu-key.pem b/samples/modules/thrift/hello/qemu-key.pem new file mode 100644 index 000000000000..d48b0160423b --- /dev/null +++ b/samples/modules/thrift/hello/qemu-key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDJe1qpuV1Cscvr +U1wwOFHcxO/xxUIWexVkSCrrbW9nmeMKcrnnstGA4onHoyf7x3Rq7VGCMiPKG5bR +oZFCeeRfyYFdKoUn6NBhMECziLkl6af5Gx/XsuNHXV4Y2Ooi1LkmqWspR9jcJ+AW +g39gkhOYOAgTR3o1slur4uS4G691/j5vEOnzD+aA6vbC8ZwAa+1oBawlVpC9eQuV +crAA5Hj8KEUtAl7F7gqt4zCe9pr/QldGJr7gj5Q9ZIlSapynB5uPygOBeqJC/7ol +cXrpmhQtUg2KIOT2q0BuSybxXP9BDwuXgwv3p9llczeyJ7Q5KwR2nN8trJjBzexf +tjynV/hJAgMBAAECggEACB/IddgILeiH6bGsjF2KVX3TvrAJcb5y9skfK+uM0VoH +6N5b5ym+H2A4azDwa4tXwyOrGfYbcdU/dVf6Xg/BWOhskN0Q+J0/YRqGPVxZujDh +VEmBn5uyZOiw+Devkd8ke8Mwk5NL/Vfr8KnK5idgZrlPAAytiEKXG4efA9p1Q9YT +N+vBh+S36uP74Y1S/gpZaVo4igd03htqZ87s+h32l4rWM/KznbUCza4tl/Kz4oaZ +SddfQ1IcU4xbU5W+r405XEQcK6UgF5+wOTEyuh5ECVblQFnvLrQELinokj2/Xbvd +zx214mxYxjff9tmt4gZcWHFAaHQfV2e28sGdQzglgQKBgQDmrruc1IOBKSfW5Xw7 +O60zpPa0gCbweh21+amAan0V0u4NLXcEp6BhDdO00Fkuw8hNNLRlGxMtIvLX/RuB +ZKstE/PPyq9O1vJBv3KoTg4uF71ic1LRb77fZrrl9xuuo6beS8agodPqlsxmREa7 +sSfP2Z1t2la8loxz4TefiWW2iQKBgQDfmDJ6uZCB1GKaSYjtIRDldxyjIUZvFbm4 +vQGQS/xJvSmMdCBbBfenHshgCDoPvCslB4aSPzi0ihUwdoWpQ+dVoNJTONxJzJwV +wEAC3GvLQYMTpEGzht8u2ullde76G33GoFSDSF9ErCZUALosjrzlHTwRjYnBo4Mu +pV4eUufDwQKBgCItlWKBIhLK9DokuilUiC70rBDGQ/6xOSGzIegC3xGStO6C4/Vu +mJaIo+tQS0Zgf5bgzjGEt2yilvRlbePX9HyzThZlY1/8/Nu879H77qHppoelqomZ +UuBqqhpUaGeRm7Gn7H/0Oh+xxAsK5qf8cXecOHUEOoGqlJi+r60VgFpxAoGABria +e9nsIBr0Q9MGDKq7yUoFUFoFtf0fMhBsZZwDH2xSPWiYOGQ7h4iDWW+l3yc23MwX +HXpNCBBGhshpSCdEYuyMpffFl2pRHs5CnlNl4hw8BnEfkHfzaYMnFOewoVAGPdw/ +7hpU0smh9VB4SDKaNwDj91sb0vhJTzOlWp//W4ECgYEAt24lSRSlKxh7RzSiXwnu +hcd92nw2fVPwIWMy/X3UY8GD1V9O1qXcBIS2IMoWDChr6yPhfZ5GzWMhAb170tDC +5Xyj6v5b3wkADEFbkd7ioBTNi34og6w16wm1MZgiVRHcAGuOarGH5MzRZQBAEmHQ +CJiEUbFAwcg3RFElByRzIx0= +-----END PRIVATE KEY----- diff --git a/samples/modules/thrift/hello/server/CMakeLists.txt b/samples/modules/thrift/hello/server/CMakeLists.txt new file mode 100644 index 000000000000..44860f78da4b --- /dev/null +++ b/samples/modules/thrift/hello/server/CMakeLists.txt @@ -0,0 +1,61 @@ +# Copyright 2022 Meta +# SPDX-License-Identifier: Apache-2.0 + +cmake_minimum_required(VERSION 3.20.0) +find_package(Zephyr REQUIRED HINTS $ENV{ZEPHYR_BASE}) +project(thrift_hello_server) + +FILE(GLOB app_sources + src/*.cpp +) + +include(${ZEPHYR_BASE}/modules/thrift/cmake/thrift.cmake) + +set(generated_sources "") +set(gen_dir ${ZEPHYR_BINARY_DIR}/misc/generated/thrift_hello) +list(APPEND generated_sources ${gen_dir}/gen-cpp/hello_types.h) +list(APPEND generated_sources ${gen_dir}/gen-cpp/Hello.cpp) +list(APPEND generated_sources ${gen_dir}/gen-cpp/Hello.h) +list(APPEND app_sources ${generated_sources}) + +thrift( + app + cpp + :no_skeleton + ${gen_dir} + ${ZEPHYR_BASE}/samples/modules/thrift/hello/hello.thrift + "" + ${generated_sources} +) + +target_sources(app PRIVATE ${app_sources}) + +# needed because std::iterator was deprecated with -std=c++17 +target_compile_options(app PRIVATE -Wno-deprecated-declarations) + +# convert .pem files to array data at build time +zephyr_include_directories(${gen_dir}) + +generate_inc_file_for_target( + app + ${ZEPHYR_BASE}/samples/modules/thrift/hello/qemu-cert.pem + ${gen_dir}/qemu_cert.pem.inc + ) + +generate_inc_file_for_target( + app + ${ZEPHYR_BASE}/samples/modules/thrift/hello/qemu-key.pem + ${gen_dir}/qemu_key.pem.inc + ) + +generate_inc_file_for_target( + app + ${ZEPHYR_BASE}/samples/modules/thrift/hello/native-cert.pem + ${gen_dir}/native_cert.pem.inc + ) + +generate_inc_file_for_target( + app + ${ZEPHYR_BASE}/samples/modules/thrift/hello/native-key.pem + ${gen_dir}/native_key.pem.inc + ) diff --git a/samples/modules/thrift/hello/server/Kconfig b/samples/modules/thrift/hello/server/Kconfig new file mode 100644 index 000000000000..8a15c50347e2 --- /dev/null +++ b/samples/modules/thrift/hello/server/Kconfig @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2022 Meta + +source "Kconfig.zephyr" + +config THRIFT_COMPACT_PROTOCOL + bool "Use TCompactProtocol in samples" + depends on THRIFT + help + Enable this option to use TCompactProtocol in samples diff --git a/samples/modules/thrift/hello/server/Makefile b/samples/modules/thrift/hello/server/Makefile new file mode 100644 index 000000000000..20fc9bf625c3 --- /dev/null +++ b/samples/modules/thrift/hello/server/Makefile @@ -0,0 +1,40 @@ +# Copyright 2022 Meta +# SPDX-License-Identifier: Apache-2.0 + +.PHONY: all clean + +CXXFLAGS := +CXXFLAGS += -std=c++17 + +GEN_DIR = gen-cpp +GENSRC = $(GEN_DIR)/Hello.cpp $(GEN_DIR)/Hello.h $(GEN_DIR)/hello_types.h +GENHDR = $(filter %.h, $(GENSRC)) +GENOBJ = $(filter-out %.h, $(GENSRC:.cpp=.o)) + +THRIFT_FLAGS := +THRIFT_FLAGS += $(shell pkg-config --cflags thrift) +THRIFT_FLAGS += -I$(GEN_DIR) +THRIFT_LIBS := +THRIFT_LIBS = $(shell pkg-config --libs thrift) + +all: hello_server hello_server_compact hello_server_ssl + +hello_server.stamp: ../hello.thrift + thrift --gen cpp:no_skeleton $< + +$(GENSRC): hello_server.stamp + +%.o: %.cpp $(GENHDR) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(THRIFT_FLAGS) -o $@ -c $< + +hello_server: src/main.cpp $(GENOBJ) $(GENHDR) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(THRIFT_FLAGS) -o $@ $< $(GENOBJ) $(THRIFT_LIBS) + +hello_server_compact: src/main.cpp $(GENOBJ) $(GENHDR) + $(CXX) -DCONFIG_THRIFT_COMPACT_PROTOCOL=1 $(CPPFLAGS) $(CXXFLAGS) $(THRIFT_FLAGS) -o $@ $< $(GENOBJ) $(THRIFT_LIBS) + +hello_server_ssl: src/main.cpp $(GENOBJ) $(GENHDR) + $(CXX) -DCONFIG_THRIFT_SSL_SOCKET=1 $(CPPFLAGS) $(CXXFLAGS) $(THRIFT_FLAGS) -o $@ $< $(GENOBJ) $(THRIFT_LIBS) + +clean: + rm -Rf hello_server hello_server_compact hello_server_ssl $(GEN_DIR) diff --git a/samples/modules/thrift/hello/server/prj.conf b/samples/modules/thrift/hello/server/prj.conf new file mode 100644 index 000000000000..20f6631202a0 --- /dev/null +++ b/samples/modules/thrift/hello/server/prj.conf @@ -0,0 +1,65 @@ +# Need a full libc++ +CONFIG_NEWLIB_LIBC=y +CONFIG_NEWLIB_LIBC_NANO=n + +# CONFIG_THRIFT Dependencies +CONFIG_CPP=y +CONFIG_STD_CPP17=y +CONFIG_CPP_EXCEPTIONS=y +CONFIG_EXTERNAL_LIBCPP=y +CONFIG_POSIX_API=y +CONFIG_NET_SOCKETPAIR=y +CONFIG_HEAP_MEM_POOL_SIZE=16384 +CONFIG_EVENTFD=y + +CONFIG_THRIFT=y + +# Generic networking options +CONFIG_NETWORKING=y +CONFIG_NET_UDP=y +CONFIG_NET_TCP=y +CONFIG_NET_IPV6=n +CONFIG_NET_IPV4=y +CONFIG_NET_SOCKETS=y +CONFIG_NET_CONNECTION_MANAGER=y + +# Kernel options +CONFIG_ENTROPY_GENERATOR=y +CONFIG_TEST_RANDOM_GENERATOR=y +CONFIG_INIT_STACKS=y + +# Logging +CONFIG_NET_LOG=y +CONFIG_LOG=y +CONFIG_NET_STATISTICS=y +CONFIG_PRINTK=y + +# Network buffers +CONFIG_NET_PKT_RX_COUNT=16 +CONFIG_NET_PKT_TX_COUNT=16 +CONFIG_NET_BUF_RX_COUNT=64 +CONFIG_NET_BUF_TX_COUNT=64 +CONFIG_NET_CONTEXT_NET_PKT_POOL=y + +# IP address options +CONFIG_NET_MAX_CONTEXTS=10 + +# Network application options and configuration +CONFIG_NET_CONFIG_SETTINGS=y +CONFIG_NET_CONFIG_NEED_IPV6=n +CONFIG_NET_CONFIG_NEED_IPV4=y +CONFIG_NET_CONFIG_MY_IPV4_ADDR="192.0.2.1" +CONFIG_NET_CONFIG_PEER_IPV4_ADDR="192.0.2.2" + +# Number of socket descriptors might need adjusting +# if there are more than 1 handlers defined. +CONFIG_POSIX_MAX_FDS=16 + +# Some platforms require relatively large stack sizes. +# This can be tuned per-board. +CONFIG_MAIN_STACK_SIZE=8192 +CONFIG_SYSTEM_WORKQUEUE_STACK_SIZE=8192 +CONFIG_NET_TCP_WORKQ_STACK_SIZE=4096 +CONFIG_NET_MGMT_EVENT_STACK_SIZE=4096 +CONFIG_IDLE_STACK_SIZE=4096 +CONFIG_NET_RX_STACK_SIZE=8192 diff --git a/samples/modules/thrift/hello/server/sample.yaml b/samples/modules/thrift/hello/server/sample.yaml new file mode 100644 index 000000000000..c56bee255607 --- /dev/null +++ b/samples/modules/thrift/hello/server/sample.yaml @@ -0,0 +1,16 @@ +sample: + description: Hello Thrift server sample + name: hello thrift server +common: + tags: thrift cpp sample + build_only: true + modules: + - thrift + platform_allow: mps2_an385 qemu_x86_64 +tests: + sample.thrift.hello.server.binaryProtocol: {} + sample.thrift.hello.server.compactProtocol: + extra_configs: + - CONFIG_THRIFT_COMPACT_PROTOCOL=y + sample.thrift.hello.server.tlsTransport: + extra_args: OVERLAY_CONFIG="../overlay-tls.conf" diff --git a/samples/modules/thrift/hello/server/src/HelloHandler.h b/samples/modules/thrift/hello/server/src/HelloHandler.h new file mode 100644 index 000000000000..24c05f668f2f --- /dev/null +++ b/samples/modules/thrift/hello/server/src/HelloHandler.h @@ -0,0 +1,46 @@ +/* + * Copyright 2022 Meta + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifdef __ZEPHYR__ +#include +#else +#define printk printf +#endif + +#include + +#include "Hello.h" + +class HelloHandler : virtual public HelloIf +{ +public: + HelloHandler() : count(0) + { + } + + void ping() + { + printk("%s\n", __func__); + } + + void echo(std::string &_return, const std::string &msg) + { + printk("%s: %s\n", __func__, msg.c_str()); + _return = msg; + } + + int32_t counter() + { + ++count; + printk("%s: %d\n", __func__, count); + return count; + } + +protected: + int count; +}; diff --git a/samples/modules/thrift/hello/server/src/main.cpp b/samples/modules/thrift/hello/server/src/main.cpp new file mode 100644 index 000000000000..f81db2ad366b --- /dev/null +++ b/samples/modules/thrift/hello/server/src/main.cpp @@ -0,0 +1,124 @@ +/* + * Copyright 2022 Young Mei + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef __ZEPHYR__ +#include +#endif + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "Hello.h" +#include "HelloHandler.h" + +using namespace ::apache::thrift; +using namespace ::apache::thrift::protocol; +using namespace ::apache::thrift::transport; +using namespace ::apache::thrift::server; + +#ifndef IS_ENABLED +#define IS_ENABLED(flag) flag +#endif + +#ifndef CONFIG_THRIFT_COMPACT_PROTOCOL +#define CONFIG_THRIFT_COMPACT_PROTOCOL 0 +#endif + +#ifndef CONFIG_THRIFT_SSL_SOCKET +#define CONFIG_THRIFT_SSL_SOCKET 0 +#endif + +#ifdef __ZEPHYR__ +int main(void) +#else +int main(int argc, char **argv) +#endif +{ + std::string my_addr; + +#ifdef __ZEPHYR__ + my_addr = CONFIG_NET_CONFIG_MY_IPV4_ADDR; +#else + if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) { + if (argc != 5) { + printf("usage: %s " + "\n", + argv[0]); + return EXIT_FAILURE; + } + } else { + if (argc != 2) { + printf("usage: %s \n", argv[0]); + return EXIT_FAILURE; + } + } + + my_addr = std::string(argv[1]); +#endif + + const int port = 4242; + std::shared_ptr serverTransport; + std::shared_ptr transportFactory; + std::shared_ptr protocolFactory; + std::shared_ptr handler(new HelloHandler()); + std::shared_ptr processor(new HelloProcessor(handler)); + + if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) { + std::shared_ptr socketFactory(new TSSLSocketFactory()); + socketFactory->server(true); +#ifdef __ZEPHYR__ + static const char qemu_cert_pem[] = { +#include "qemu_cert.pem.inc" + }; + + static const char qemu_key_pem[] = { +#include "qemu_key.pem.inc" + }; + + static const char native_cert_pem[] = { +#include "native_cert.pem.inc" + }; + + socketFactory->loadCertificateFromBuffer(qemu_cert_pem); + socketFactory->loadPrivateKeyFromBuffer(qemu_key_pem); + socketFactory->loadTrustedCertificatesFromBuffer(native_cert_pem); +#else + socketFactory->loadCertificate(argv[2]); + socketFactory->loadPrivateKey(argv[3]); + socketFactory->loadTrustedCertificates(argv[4]); +#endif + serverTransport = + std::make_shared("0.0.0.0", port, socketFactory); + } else { + serverTransport = std::make_shared(my_addr, port); + } + + transportFactory = std::make_shared(); + if (IS_ENABLED(CONFIG_THRIFT_COMPACT_PROTOCOL)) { + protocolFactory = std::make_shared(); + } else { + protocolFactory = std::make_shared(); + } + + TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory); + + try { + server.serve(); + } catch (std::exception &e) { + printf("caught exception: %s\n", e.what()); + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} diff --git a/tests/lib/thrift/ThriftTest/CMakeLists.txt b/tests/lib/thrift/ThriftTest/CMakeLists.txt new file mode 100644 index 000000000000..bbdeacb23e13 --- /dev/null +++ b/tests/lib/thrift/ThriftTest/CMakeLists.txt @@ -0,0 +1,57 @@ +# Copyright 2022 Meta +# SPDX-License-Identifier: Apache-2.0 + +cmake_minimum_required(VERSION 3.20.0) +find_package(Zephyr REQUIRED HINTS $ENV{ZEPHYR_BASE}) +project(thrift_test) + +set(THRIFT_UPSTREAM ${ZEPHYR_THRIFT_MODULE_DIR}) + +include(${ZEPHYR_BASE}/modules/thrift/cmake/thrift.cmake) + +FILE(GLOB app_sources + src/*.cpp +) + +set(generated_sources "") +set(gen_dir ${ZEPHYR_BINARY_DIR}/misc/generated/thrift_ThriftTest) +list(APPEND generated_sources ${gen_dir}/gen-cpp/SecondService.cpp) +list(APPEND generated_sources ${gen_dir}/gen-cpp/SecondService.h) +list(APPEND generated_sources ${gen_dir}/gen-cpp/ThriftTest_constants.cpp) +list(APPEND generated_sources ${gen_dir}/gen-cpp/ThriftTest_constants.h) +list(APPEND generated_sources ${gen_dir}/gen-cpp/ThriftTest.cpp) +list(APPEND generated_sources ${gen_dir}/gen-cpp/ThriftTest.h) +list(APPEND generated_sources ${gen_dir}/gen-cpp/ThriftTest_types.cpp) +list(APPEND generated_sources ${gen_dir}/gen-cpp/ThriftTest_types.h) +list(APPEND app_sources ${generated_sources}) + +thrift( + app + cpp + :no_skeleton + ${gen_dir} + # v0.16: ubuntu packaged thrift compiler does not support 'uuid' type + "${THRIFT_UPSTREAM}/test/v0.16/ThriftTest.thrift" + "" + ${generated_sources} +) + +target_sources(app PRIVATE ${app_sources}) + +# needed because std::iterator was deprecated with -std=c++17 +target_compile_options(app PRIVATE -Wno-deprecated-declarations) + +# convert .pem files to array data at build time +zephyr_include_directories(${gen_dir}) + +generate_inc_file_for_target( + app + ${ZEPHYR_BASE}/samples/modules/thrift/hello/qemu-cert.pem + ${gen_dir}/qemu_cert.pem.inc + ) + +generate_inc_file_for_target( + app + ${ZEPHYR_BASE}/samples/modules/thrift/hello/qemu-key.pem + ${gen_dir}/qemu_key.pem.inc + ) diff --git a/tests/lib/thrift/ThriftTest/Kconfig b/tests/lib/thrift/ThriftTest/Kconfig new file mode 100644 index 000000000000..51c97dda571c --- /dev/null +++ b/tests/lib/thrift/ThriftTest/Kconfig @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2022 Meta + +source "Kconfig.zephyr" + +config THRIFTTEST_SERVER_STACK_SIZE + int "ThriftTest Server stack size" + default 2048 + +config THRIFTTEST_LOG_LEVEL + int "ThriftTest log level" + default 4 + +config THRIFT_COMPACT_PROTOCOL + bool "Use TCompactProtocol for tests" + depends on THRIFT + default y + help + Enable this option to include TCompactProtocol in tests diff --git a/tests/lib/thrift/ThriftTest/overlay-tls.conf b/tests/lib/thrift/ThriftTest/overlay-tls.conf new file mode 100644 index 000000000000..2930b1fdc45c --- /dev/null +++ b/tests/lib/thrift/ThriftTest/overlay-tls.conf @@ -0,0 +1,10 @@ +CONFIG_THRIFT_SSL_SOCKET=y + +# TLS configuration +CONFIG_MBEDTLS=y +CONFIG_MBEDTLS_PEM_CERTIFICATE_FORMAT=y +CONFIG_MBEDTLS_ENABLE_HEAP=y +CONFIG_MBEDTLS_HEAP_SIZE=48000 +CONFIG_MBEDTLS_SSL_MAX_CONTENT_LEN=2048 +CONFIG_NET_SOCKETS_TLS_MAX_CONTEXTS=6 +CONFIG_NET_SOCKETS_SOCKOPT_TLS=y diff --git a/tests/lib/thrift/ThriftTest/prj.conf b/tests/lib/thrift/ThriftTest/prj.conf new file mode 100755 index 000000000000..35361bf9b72f --- /dev/null +++ b/tests/lib/thrift/ThriftTest/prj.conf @@ -0,0 +1,47 @@ +CONFIG_NEWLIB_LIBC=y +CONFIG_NEWLIB_LIBC_NANO=n + +# CONFIG_THRIFT Dependencies +CONFIG_CPP=y +CONFIG_STD_CPP17=y +CONFIG_CPP_EXCEPTIONS=y +CONFIG_GLIBCXX_LIBCPP=y +CONFIG_POSIX_API=y +CONFIG_NETWORKING=y +CONFIG_NET_TCP=y +CONFIG_NET_SOCKETS=y +CONFIG_NET_SOCKETPAIR=y +CONFIG_HEAP_MEM_POOL_SIZE=16384 +CONFIG_EVENTFD=y + +CONFIG_THRIFT=y + +# Test dependencies +CONFIG_ZTEST=y +CONFIG_ZTEST_NEW_API=y +CONFIG_TEST_RANDOM_GENERATOR=y +CONFIG_NET_TEST=y +CONFIG_NET_DRIVERS=y +CONFIG_NET_LOOPBACK=y + +# Some platforms require relatively large stack sizes. +# This can be tuned per-board. +CONFIG_ZTEST_STACK_SIZE=8192 +CONFIG_MAIN_STACK_SIZE=4096 +CONFIG_SYSTEM_WORKQUEUE_STACK_SIZE=8192 +CONFIG_THRIFTTEST_SERVER_STACK_SIZE=8192 +CONFIG_NET_TCP_WORKQ_STACK_SIZE=4096 +CONFIG_NET_MGMT_EVENT_STACK_SIZE=4096 +CONFIG_IDLE_STACK_SIZE=4096 +CONFIG_NET_RX_STACK_SIZE=8192 + +CONFIG_NET_BUF_TX_COUNT=20 +CONFIG_NET_PKT_TX_COUNT=20 +CONFIG_NET_BUF_RX_COUNT=20 +CONFIG_NET_PKT_RX_COUNT=20 +CONFIG_POSIX_MAX_FDS=16 + +# Network address config +CONFIG_NET_IPV4=y +CONFIG_NET_CONFIG_SETTINGS=y +CONFIG_NET_CONFIG_MY_IPV4_ADDR="192.0.2.1" diff --git a/tests/lib/thrift/ThriftTest/src/client.cpp b/tests/lib/thrift/ThriftTest/src/client.cpp new file mode 100644 index 000000000000..ea10b6fb7592 --- /dev/null +++ b/tests/lib/thrift/ThriftTest/src/client.cpp @@ -0,0 +1,246 @@ +/* + * Copyright 2022 Young Mei + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include + +#include + +#include "context.hpp" + +using namespace apache::thrift; + +using namespace std; + +static void init_Xtruct(Xtruct &s); + +ZTEST(thrift, test_void) +{ + context.client->testVoid(); +} + +ZTEST(thrift, test_string) +{ + string s; + context.client->testString(s, "Test"); + zassert_equal(s, "Test", ""); +} + +ZTEST(thrift, test_bool) +{ + zassert_equal(false, context.client->testBool(false), ""); + zassert_equal(true, context.client->testBool(true), ""); +} + +ZTEST(thrift, test_byte) +{ + zassert_equal(0, context.client->testByte(0), ""); + zassert_equal(-1, context.client->testByte(-1), ""); + zassert_equal(42, context.client->testByte(42), ""); + zassert_equal(-42, context.client->testByte(-42), ""); + zassert_equal(127, context.client->testByte(127), ""); + zassert_equal(-128, context.client->testByte(-128), ""); +} + +ZTEST(thrift, test_i32) +{ + zassert_equal(0, context.client->testI32(0), ""); + zassert_equal(-1, context.client->testI32(-1), ""); + zassert_equal(190000013, context.client->testI32(190000013), ""); + zassert_equal(-190000013, context.client->testI32(-190000013), ""); + zassert_equal(INT32_MAX, context.client->testI32(INT32_MAX), ""); + zassert_equal(INT32_MIN, context.client->testI32(INT32_MIN), ""); +} + +ZTEST(thrift, test_i64) +{ + zassert_equal(0, context.client->testI64(0), ""); + zassert_equal(-1, context.client->testI64(-1), ""); + zassert_equal(7000000000000000123LL, context.client->testI64(7000000000000000123LL), ""); + zassert_equal(-7000000000000000123LL, context.client->testI64(-7000000000000000123LL), ""); + zassert_equal(INT64_MAX, context.client->testI64(INT64_MAX), ""); + zassert_equal(INT64_MIN, context.client->testI64(INT64_MIN), ""); +} + +ZTEST(thrift, test_double) +{ + zassert_equal(0.0, context.client->testDouble(0.0), ""); + zassert_equal(-1.0, context.client->testDouble(-1.0), ""); + zassert_equal(-5.2098523, context.client->testDouble(-5.2098523), ""); + zassert_equal(-0.000341012439638598279, + context.client->testDouble(-0.000341012439638598279), ""); + zassert_equal(DBL_MAX, context.client->testDouble(DBL_MAX), ""); + zassert_equal(-DBL_MAX, context.client->testDouble(-DBL_MAX), ""); +} + +ZTEST(thrift, test_binary) +{ + string rsp; + + context.client->testBinary(rsp, ""); + zassert_equal("", rsp, ""); + context.client->testBinary(rsp, "Hello"); + zassert_equal("Hello", rsp, ""); + context.client->testBinary(rsp, "H\x03\x01\x01\x00"); + zassert_equal("H\x03\x01\x01\x00", rsp, ""); +} + +ZTEST(thrift, test_struct) +{ + Xtruct request_struct; + init_Xtruct(request_struct); + Xtruct response_struct; + context.client->testStruct(response_struct, request_struct); + + zassert_equal(response_struct, request_struct, NULL); +} + +ZTEST(thrift, test_nested_struct) +{ + Xtruct2 request_struct; + request_struct.byte_thing = 1; + init_Xtruct(request_struct.struct_thing); + request_struct.i32_thing = 5; + Xtruct2 response_struct; + context.client->testNest(response_struct, request_struct); + + zassert_equal(response_struct, request_struct, NULL); +} + +ZTEST(thrift, test_map) +{ + static const map request_map = { + {0, -10}, {1, -9}, {2, -8}, {3, -7}, {4, -6}}; + + map response_map; + context.client->testMap(response_map, request_map); + + zassert_equal(request_map, response_map, ""); +} + +ZTEST(thrift, test_string_map) +{ + static const map request_smap = { + {"a", "2"}, {"b", "blah"}, {"some", "thing"} + }; + map response_smap; + + context.client->testStringMap(response_smap, request_smap); + zassert_equal(response_smap, request_smap, ""); +} + +ZTEST(thrift, test_set) +{ + static const set request_set = {-2, -1, 0, 1, 2}; + + set response_set; + context.client->testSet(response_set, request_set); + + zassert_equal(request_set, response_set, ""); +} + +ZTEST(thrift, test_list) +{ + vector response_list; + context.client->testList(response_list, vector()); + zassert_true(response_list.empty(), "Unexpected list size: %llu", response_list.size()); + + static const vector request_list = {-2, -1, 0, 1, 2}; + + response_list.clear(); + context.client->testList(response_list, request_list); + zassert_equal(request_list, response_list, ""); +} + +ZTEST(thrift, test_enum) +{ + Numberz::type response = context.client->testEnum(Numberz::ONE); + zassert_equal(response, Numberz::ONE, NULL); + + response = context.client->testEnum(Numberz::TWO); + zassert_equal(response, Numberz::TWO, NULL); + + response = context.client->testEnum(Numberz::EIGHT); + zassert_equal(response, Numberz::EIGHT, NULL); +} + +ZTEST(thrift, test_typedef) +{ + UserId uid = context.client->testTypedef(309858235082523LL); + zassert_equal(uid, 309858235082523LL, "Unexpected uid: %llu", uid); +} + +ZTEST(thrift, test_nested_map) +{ + map> mm; + context.client->testMapMap(mm, 1); + + zassert_equal(mm.size(), 2, NULL); + zassert_equal(mm[-4][-4], -4, NULL); + zassert_equal(mm[-4][-3], -3, NULL); + zassert_equal(mm[-4][-2], -2, NULL); + zassert_equal(mm[-4][-1], -1, NULL); + zassert_equal(mm[4][4], 4, NULL); + zassert_equal(mm[4][3], 3, NULL); + zassert_equal(mm[4][2], 2, NULL); + zassert_equal(mm[4][1], 1, NULL); +} + +ZTEST(thrift, test_exception) +{ + std::exception_ptr eptr = nullptr; + + try { + context.client->testException("Xception"); + } catch (...) { + eptr = std::current_exception(); + } + zassert_not_equal(nullptr, eptr, "an exception was not thrown"); + + eptr = nullptr; + try { + context.client->testException("TException"); + } catch (...) { + eptr = std::current_exception(); + } + zassert_not_equal(nullptr, eptr, "an exception was not thrown"); + + context.client->testException("success"); +} + +ZTEST(thrift, test_multi_exception) +{ + std::exception_ptr eptr = nullptr; + + try { + Xtruct result; + context.client->testMultiException(result, "Xception", "test 1"); + } catch (...) { + eptr = std::current_exception(); + } + zassert_not_equal(nullptr, eptr, "an exception was not thrown"); + + eptr = nullptr; + try { + Xtruct result; + context.client->testMultiException(result, "Xception2", "test 2"); + } catch (...) { + eptr = std::current_exception(); + } + zassert_not_equal(nullptr, eptr, "an exception was not thrown"); +} + +static void init_Xtruct(Xtruct &s) +{ + s.string_thing = "Zero"; + s.byte_thing = 1; + s.i32_thing = -3; + s.i64_thing = -5; +} diff --git a/tests/lib/thrift/ThriftTest/src/context.hpp b/tests/lib/thrift/ThriftTest/src/context.hpp new file mode 100644 index 000000000000..39cf9e8590a6 --- /dev/null +++ b/tests/lib/thrift/ThriftTest/src/context.hpp @@ -0,0 +1,35 @@ +/* + * Copyright 2022 Young Mei + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef TESTS_LIB_THRIFT_THRIFTTEST_SRC_CONTEXT_HPP_ +#define TESTS_LIB_THRIFT_THRIFTTEST_SRC_CONTEXT_HPP_ + +#include + +#include + +#include + +#include "ThriftTest.h" + +using namespace apache::thrift::server; +using namespace thrift::test; + +struct ctx { + enum { + SERVER, + CLIENT, + }; + + std::array fds; + std::unique_ptr client; + std::unique_ptr server; + pthread_t server_thread; +}; + +extern ctx context; + +#endif /* TESTS_LIB_THRIFT_THRIFTTEST_SRC_CONTEXT_HPP_ */ diff --git a/tests/lib/thrift/ThriftTest/src/main.cpp b/tests/lib/thrift/ThriftTest/src/main.cpp new file mode 100644 index 000000000000..b8ac60480084 --- /dev/null +++ b/tests/lib/thrift/ThriftTest/src/main.cpp @@ -0,0 +1,166 @@ +/* + * Copyright 2022 Young Mei + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "context.hpp" +#include "server.hpp" +#include "thrift/server/TFDServer.h" + +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; + +ctx context; +static K_THREAD_STACK_DEFINE(ThriftTest_server_stack, CONFIG_THRIFTTEST_SERVER_STACK_SIZE); +static const char cert_pem[] = { +#include "qemu_cert.pem.inc" +}; +static const char key_pem[] = { +#include "qemu_key.pem.inc" +}; + +static void *server_func(void *arg) +{ + (void)arg; + + context.server->serve(); + + return nullptr; +} + +static void *thrift_test_setup(void) +{ + if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) { + TSSLSocketFactory socketFactory; + socketFactory.loadCertificateFromBuffer((const char *)&cert_pem[0]); + socketFactory.loadPrivateKeyFromBuffer((const char *)&key_pem[0]); + socketFactory.loadTrustedCertificatesFromBuffer((const char *)&cert_pem[0]); + } + + return NULL; +} + +static std::unique_ptr setup_client() +{ + std::shared_ptr transport; + std::shared_ptr protocol; + std::shared_ptr trans(new TFDTransport(context.fds[ctx::CLIENT])); + + if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) { + const int port = 4242; + std::shared_ptr socketFactory = + std::make_shared(); + socketFactory->authenticate(true); + trans = socketFactory->createSocket(CONFIG_NET_CONFIG_MY_IPV4_ADDR, port); + } else { + trans = std::make_shared(context.fds[ctx::CLIENT]); + } + + transport = std::make_shared(trans); + + if (IS_ENABLED(CONFIG_THRIFT_COMPACT_PROTOCOL)) { + protocol = std::make_shared(transport); + } else { + protocol = std::make_shared(transport); + } + transport->open(); + return std::unique_ptr(new ThriftTestClient(protocol)); +} + +static std::unique_ptr setup_server() +{ + std::shared_ptr handler(new TestHandler()); + std::shared_ptr processor(new ThriftTestProcessor(handler)); + std::shared_ptr serverTransport; + std::shared_ptr protocolFactory; + std::shared_ptr transportFactory; + + if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) { + const int port = 4242; + std::shared_ptr socketFactory(new TSSLSocketFactory()); + socketFactory->server(true); + serverTransport = + std::make_shared("0.0.0.0", port, socketFactory); + } else { + serverTransport = std::make_shared(context.fds[ctx::SERVER]); + } + + transportFactory = std::make_shared(); + + if (IS_ENABLED(CONFIG_THRIFT_COMPACT_PROTOCOL)) { + protocolFactory = std::make_shared(); + } else { + protocolFactory = std::make_shared(); + } + TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory); + return std::unique_ptr( + new TSimpleServer(processor, serverTransport, transportFactory, protocolFactory)); +} + +static void thrift_test_before(void *data) +{ + ARG_UNUSED(data); + int rv; + + pthread_attr_t attr; + pthread_attr_t *attrp = &attr; + + if (IS_ENABLED(CONFIG_ARCH_POSIX)) { + attrp = NULL; + } else { + rv = pthread_attr_init(attrp); + zassert_equal(0, rv, "pthread_attr_init failed: %d", rv); + rv = pthread_attr_setstack(attrp, ThriftTest_server_stack, + CONFIG_THRIFTTEST_SERVER_STACK_SIZE); + zassert_equal(0, rv, "pthread_attr_setstack failed: %d", rv); + } + + // create the communication channel + rv = socketpair(AF_UNIX, SOCK_STREAM, 0, &context.fds.front()); + zassert_equal(0, rv, "socketpair failed: %d\n", rv); + + // set up server + context.server = setup_server(); + + // start the server + rv = pthread_create(&context.server_thread, attrp, server_func, nullptr); + zassert_equal(0, rv, "pthread_create failed: %d", rv); + + // set up client + context.client = setup_client(); +} + +static void thrift_test_after(void *data) +{ + ARG_UNUSED(data); + void *unused; + + context.server->stop(); + + pthread_join(context.server_thread, &unused); + + context.server.reset(); + context.client.reset(); + + for (auto &fd : context.fds) { + close(fd); + fd = -1; + } +} + +ZTEST_SUITE(thrift, NULL, thrift_test_setup, thrift_test_before, thrift_test_after, NULL); diff --git a/tests/lib/thrift/ThriftTest/src/server.hpp b/tests/lib/thrift/ThriftTest/src/server.hpp new file mode 100644 index 000000000000..9f1ac92627ca --- /dev/null +++ b/tests/lib/thrift/ThriftTest/src/server.hpp @@ -0,0 +1,325 @@ +/* + * Copyright 2022 Young Mei + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +#include +#include + +#include "SecondService.h" +#include "ThriftTest.h" + +using namespace std; + +using namespace apache::thrift; +using namespace apache::thrift::transport; +using namespace thrift::test; + +class TestHandler : public ThriftTestIf +{ +public: + TestHandler() = default; + + void testVoid() override + { + printf("testVoid()\n"); + } + + void testString(string &out, const string &thing) override + { + printf("testString(\"%s\")\n", thing.c_str()); + out = thing; + } + + bool testBool(const bool thing) override + { + printf("testBool(%s)\n", thing ? "true" : "false"); + return thing; + } + + int8_t testByte(const int8_t thing) override + { + printf("testByte(%d)\n", (int)thing); + return thing; + } + + int32_t testI32(const int32_t thing) override + { + printf("testI32(%d)\n", thing); + return thing; + } + + int64_t testI64(const int64_t thing) override + { + printf("testI64(%" PRId64 ")\n", thing); + return thing; + } + + double testDouble(const double thing) override + { + printf("testDouble(%f)\n", thing); + return thing; + } + + void testBinary(std::string &_return, const std::string &thing) override + { + std::ostringstream hexstr; + + hexstr << std::hex << thing; + printf("testBinary(%lu: %s)\n", safe_numeric_cast(thing.size()), + hexstr.str().c_str()); + _return = thing; + } + + void testStruct(Xtruct &out, const Xtruct &thing) override + { + printf("testStruct({\"%s\", %d, %d, %" PRId64 "})\n", thing.string_thing.c_str(), + (int)thing.byte_thing, thing.i32_thing, thing.i64_thing); + out = thing; + } + + void testNest(Xtruct2 &out, const Xtruct2 &nest) override + { + const Xtruct &thing = nest.struct_thing; + + printf("testNest({%d, {\"%s\", %d, %d, %" PRId64 "}, %d})\n", (int)nest.byte_thing, + thing.string_thing.c_str(), (int)thing.byte_thing, thing.i32_thing, + thing.i64_thing, nest.i32_thing); + out = nest; + } + + void testMap(map &out, const map &thing) override + { + map::const_iterator m_iter; + bool first = true; + + printf("testMap({"); + for (m_iter = thing.begin(); m_iter != thing.end(); ++m_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + + printf("%d => %d", m_iter->first, m_iter->second); + } + + printf("})\n"); + out = thing; + } + + void testStringMap(map &out, + const map &thing) override + { + map::const_iterator m_iter; + bool first = true; + + printf("testMap({"); + for (m_iter = thing.begin(); m_iter != thing.end(); ++m_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%s => %s", (m_iter->first).c_str(), (m_iter->second).c_str()); + } + + printf("})\n"); + out = thing; + } + + void testSet(set &out, const set &thing) override + { + set::const_iterator s_iter; + bool first = true; + + printf("testSet({"); + for (s_iter = thing.begin(); s_iter != thing.end(); ++s_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + + printf("%d", *s_iter); + } + + printf("})\n"); + out = thing; + } + + void testList(vector &out, const vector &thing) override + { + vector::const_iterator l_iter; + bool first = true; + + printf("testList({"); + for (l_iter = thing.begin(); l_iter != thing.end(); ++l_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%d", *l_iter); + } + + printf("})\n"); + out = thing; + } + + Numberz::type testEnum(const Numberz::type thing) override + { + printf("testEnum(%d)\n", thing); + return thing; + } + + UserId testTypedef(const UserId thing) override + { + printf("testTypedef(%" PRId64 ")\n", thing); + return thing; + } + + void testMapMap(map> &mapmap, const int32_t hello) override + { + map pos; + map neg; + + printf("testMapMap(%d)\n", hello); + for (int i = 1; i < 5; i++) { + pos.insert(make_pair(i, i)); + neg.insert(make_pair(-i, -i)); + } + + mapmap.insert(make_pair(4, pos)); + mapmap.insert(make_pair(-4, neg)); + } + + void testInsanity(map> &insane, + const Insanity &argument) override + { + Insanity looney; + map first_map; + map second_map; + + first_map.insert(make_pair(Numberz::TWO, argument)); + first_map.insert(make_pair(Numberz::THREE, argument)); + + second_map.insert(make_pair(Numberz::SIX, looney)); + + insane.insert(make_pair(1, first_map)); + insane.insert(make_pair(2, second_map)); + + printf("testInsanity()\n"); + printf("return"); + printf(" = {"); + map>::const_iterator i_iter; + + for (i_iter = insane.begin(); i_iter != insane.end(); ++i_iter) { + printf("%" PRId64 " => {", i_iter->first); + map::const_iterator i2_iter; + + for (i2_iter = i_iter->second.begin(); i2_iter != i_iter->second.end(); + ++i2_iter) { + printf("%d => {", i2_iter->first); + map userMap = i2_iter->second.userMap; + map::const_iterator um; + + printf("{"); + for (um = userMap.begin(); um != userMap.end(); ++um) { + printf("%d => %" PRId64 ", ", um->first, um->second); + } + + printf("}, "); + vector xtructs = i2_iter->second.xtructs; + vector::const_iterator x; + + printf("{"); + for (x = xtructs.begin(); x != xtructs.end(); ++x) { + printf("{\"%s\", %d, %d, %" PRId64 "}, ", + x->string_thing.c_str(), (int)x->byte_thing, + x->i32_thing, x->i64_thing); + } + + printf("}"); + printf("}, "); + } + + printf("}, "); + } + + printf("}\n"); + } + + void testMulti(Xtruct &hello, const int8_t arg0, const int32_t arg1, const int64_t arg2, + const std::map &arg3, const Numberz::type arg4, + const UserId arg5) override + { + (void)arg3; + (void)arg4; + (void)arg5; + printf("testMulti()\n"); + hello.string_thing = "Hello2"; + hello.byte_thing = arg0; + hello.i32_thing = arg1; + hello.i64_thing = (int64_t)arg2; + } + + void testException(const std::string &arg) override + { + printf("testException(%s)\n", arg.c_str()); + if (arg.compare("Xception") == 0) { + Xception e; + e.errorCode = 1001; + e.message = arg; + throw e; + } else if (arg.compare("TException") == 0) { + apache::thrift::TException e; + throw e; + } else { + Xtruct result; + result.string_thing = arg; + return; + } + } + + void testMultiException(Xtruct &result, const std::string &arg0, + const std::string &arg1) override + { + + printf("testMultiException(%s, %s)\n", arg0.c_str(), arg1.c_str()); + if (arg0.compare("Xception") == 0) { + Xception e; + e.errorCode = 1001; + e.message = "This is an Xception"; + throw e; + } else if (arg0.compare("Xception2") == 0) { + Xception2 e; + e.errorCode = 2002; + e.struct_thing.string_thing = "This is an Xception2"; + throw e; + } else { + result.string_thing = arg1; + return; + } + } + + void testOneway(const int32_t aNum) override + { + printf("testOneway(%d): call received\n", aNum); + } +}; + +class SecondHandler : public SecondServiceIf +{ +public: + void secondtestString(std::string &result, const std::string &thing) override + { + result = "testString(\"" + thing + "\")"; + } +}; diff --git a/tests/lib/thrift/ThriftTest/testcase.yaml b/tests/lib/thrift/ThriftTest/testcase.yaml new file mode 100644 index 000000000000..1ec2302e56f3 --- /dev/null +++ b/tests/lib/thrift/ThriftTest/testcase.yaml @@ -0,0 +1,15 @@ +common: + tags: thrift cpp newlib + modules: + - thrift + filter: TOOLCHAIN_HAS_NEWLIB == 1 + # qemu_x86 exluded due to missing long double functions in SDK + # See https://github.com/zephyrproject-rtos/sdk-ng/issues/603 + platform_allow: mps2_an385 qemu_cortex_a53 qemu_riscv32 qemu_riscv64 qemu_x86_64 +tests: + thrift.ThriftTest.newlib.binaryProtocol: {} + thrift.ThriftTest.newlib.compactProtocol: + extra_configs: + - CONFIG_THRIFT_COMPACT_PROTOCOL=y + thrift.ThriftTest.newlib.tlsTransport: + extra_args: OVERLAY_CONFIG="overlay-tls.conf" diff --git a/west.yml b/west.yml index f0ea7b86bf55..56c85cb393c9 100644 --- a/west.yml +++ b/west.yml @@ -259,6 +259,9 @@ manifest: - name: zscilib path: modules/lib/zscilib revision: 0035be5e6a45e4ab89755b176d305d7a877fc79c + - name: thrift + path: modules/lib/thrift + revision: 10023645a0e6cb7ce23fcd7fd3dbac9f18df6234 self: path: zephyr