From 7c83d34d2473fb4bf9d97a8519da44206dcb07a7 Mon Sep 17 00:00:00 2001
From: Dilip Gowda Bhagavan
Date: Fri, 31 Jan 2025 17:02:09 +0530
Subject: [PATCH 01/13] s390x vector intrinsics support for vLLM

Signed-off-by: Dilip Gowda Bhagavan
---
 cmake/cpu_extension.cmake  |  11 +-
 csrc/cpu/attention.cpp     |   4 +-
 csrc/cpu/cpu_types.hpp     |   3 +
 csrc/cpu/cpu_types_vxe.cpp | 480 +++++++++++++++++++++++++++++++++++++
 csrc/cpu/quant.cpp         |   2 +-
 5 files changed, 496 insertions(+), 4 deletions(-)
 create mode 100644 csrc/cpu/cpu_types_vxe.cpp

diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index 714abca2a5f..14186fd66be 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -81,6 +81,7 @@ else()
     find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
     find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
     find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
+    find_isa(${CPUINFO} "S390" S390_FOUND)
 endif()
 
 
@@ -129,8 +130,16 @@ elseif (ASIMD_FOUND)
 elseif(APPLE_SILICON_FOUND)
     message(STATUS "Apple Silicon Detected")
     set(ENABLE_NUMA OFF)
+ elseif (S390_FOUND)
+    message(STATUS "S390 detected")
+    # Check for S390 VXE support
+    list(APPEND CXX_COMPILE_FLAGS
+        "-mvx"
+        "-mzvector"
+        "-march=native"
+        "-mtune=native")
 else()
-    message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.")
+    message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA or ARMv8 support.")
 endif()
 
 #
diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp
index b9764056e8a..0257d8ff16b 100644
--- a/csrc/cpu/attention.cpp
+++ b/csrc/cpu/attention.cpp
@@ -24,8 +24,8 @@ struct KernelVecType {
 
 template <>
 struct KernelVecType<float> {
-#ifdef __powerpc64__
-  // Power architecture-specific vector types
+#if defined(__powerpc64__) || defined(__s390x__)
+  // Power and s390x architecture-specific vector types
   using q_load_vec_type = vec_op::FP32Vec8;
   using k_load_vec_type = vec_op::FP32Vec16;
   using v_load_vec_type = vec_op::FP32Vec16;
diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp
index a7181510613..17bbe04eef9 100644
--- a/csrc/cpu/cpu_types.hpp
+++ b/csrc/cpu/cpu_types.hpp
@@ -7,6 +7,9 @@
 #elif defined(__POWER9_VECTOR__)
   // ppc implementation
   #include "cpu_types_vsx.hpp"
+#elif defined(__s390x__)
+  // s390 implementation
+  #include "cpu_types_vxe.hpp"
 #elif defined(__aarch64__)
   // arm implementation
   #include "cpu_types_arm.hpp"
diff --git a/csrc/cpu/cpu_types_vxe.cpp b/csrc/cpu/cpu_types_vxe.cpp
new file mode 100644
index 00000000000..ab8cbbbf4ec
--- /dev/null
+++ b/csrc/cpu/cpu_types_vxe.cpp
@@ -0,0 +1,480 @@
+
+#ifndef CPU_TYPES_VXE_HPP
+#define CPU_TYPES_VXE_HPP
+
+#include <cmath>
+#include <vecintrin.h>
+#include <torch/all.h>
+namespace vec_op {
+
+#define vec_neg(a) (-(a))
+#define vec_add(a, b) ((a) + (b))
+#define vec_sub(a, b) ((a) - (b))
+#define vec_mul(a, b) ((a) * (b))
+#define vec_div(a, b) ((a) / (b))
+#define vec_sr(a, b) ((a) >> (b))  // Vector Shift Right Algebraic
+#define vec_sl(a, b) ((a) << (b))  // Vector Shift Left
+
+// FIXME: FP16 is not fully supported in Torch-CPU
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
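+
+// Illustration (assumed usage, not part of this patch; `my_kernel_impl` is a
+// placeholder name): kernels typically expand the dispatch macro as
+//
+//   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "my_kernel", [&] {
+//     my_kernel_impl<scalar_t>(input.data_ptr<scalar_t>(), num_elems);
+//   });
+//
+// which stamps the lambda body out once for float and once for bfloat16,
+// with `scalar_t` bound to the matching C++ type in each case.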
+
+#ifndef CPU_OP_GUARD
+  #define CPU_KERNEL_GUARD_IN(NAME)
+  #define CPU_KERNEL_GUARD_OUT(NAME)
+#else
+  #define CPU_KERNEL_GUARD_IN(NAME) \
+    std::cout << #NAME << " invoked." << std::endl;
+  #define CPU_KERNEL_GUARD_OUT(NAME) \
+    std::cout << #NAME << " exit." << std::endl;
+#endif
+
+#define FORCE_INLINE __attribute__((always_inline)) inline
+
+namespace {
+template <typename T, T... indexes, typename F>
+constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
+  (f(std::integral_constant<T, indexes>{}), ...);
+}
+};  // namespace
+
+template <typename T, T count, typename F,
+          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
+constexpr void unroll_loop(F&& f) {
+  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
+}
+
+template <typename T>
+struct Vec {
+  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
+};
+
+typedef struct ss16x8x2_t {
+  __vector signed short val[2];
+} ss16x8x2_t;
+
+typedef struct ss16x8x4_t {
+  __vector signed short val[4];
+} ss16x8x4_t;
+
+typedef struct f32x4x2_t {
+  __vector float val[2];
+} f32x4x2_t;
+
+typedef struct f32x4x4_t {
+  __vector float val[4];
+} f32x4x4_t;
+
+struct FP32Vec8;
+struct FP32Vec16;
+
+struct BF16Vec8 : public Vec<BF16Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+
+  __vector signed short reg;
+
+  explicit BF16Vec8(const void* ptr) : reg(*(__vector signed short*)ptr) {}
+  explicit BF16Vec8(const FP32Vec8&);
+
+  void save(void* ptr) const {
+    *reinterpret_cast<__vector signed short*>(ptr) = reg;
+  }
+};
+
+struct BF16Vec16 : public Vec<BF16Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+
+  ss16x8x2_t reg;
+
+  explicit BF16Vec16(const void* ptr) {
+    // Load 256 bits in two parts
+    reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
+    reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
+  }
+
+  explicit BF16Vec16(const FP32Vec16&);
+
+  void save(void* ptr) const {
+    // Save 256 bits in two parts
+    vec_xst(reg.val[0], 0, (signed short*)ptr);
+    vec_xst(reg.val[1], 16, (signed short*)ptr);
+  }
+};
+
+const static __vector signed short zero = vec_splats((signed short)0);
+
+struct BF16Vec32 : public Vec<BF16Vec32> {
+  constexpr static int VEC_ELEM_NUM = 32;
+
+  ss16x8x4_t reg;
+  explicit BF16Vec32(const void* ptr)
+      : reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
+
+  explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
+
+  explicit BF16Vec32(const BF16Vec8& vec8_data)
+      : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
+
+  void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
+};
+
+struct FP32Vec4 : public Vec<FP32Vec4> {
+  constexpr static int VEC_ELEM_NUM = 4;
+  union AliasReg {
+    __vector float reg;
+    float values[VEC_ELEM_NUM];
+  };
+
+  __vector float reg;
+
+  explicit FP32Vec4(float v) : reg(vec_splats(v)) {}
+
+  explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
+
+  explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
+
+  explicit FP32Vec4(__vector float data) : reg(data) {}
+
+  explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
+};
+
+struct FP32Vec8 : public Vec<FP32Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+  union AliasReg {
+    f32x4x2_t reg;
+    float values[VEC_ELEM_NUM];
+  };
+
+  f32x4x2_t reg;
+
+  explicit FP32Vec8(float v) {
+    reg.val[0] = vec_splats(v);
+    reg.val[1] = vec_splats(v);
+  }
+
+  explicit FP32Vec8() {
+    reg.val[0] = vec_splats(0.0f);
+    reg.val[1] = vec_splats(0.0f);
+  }
+
+  explicit FP32Vec8(const float* ptr) {
+    reg.val[0] = vec_xl(0, ptr);
+    reg.val[1] = vec_xl(16, ptr);
+  }
+
+  explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
+
+  explicit FP32Vec8(const FP32Vec8& data) {
+    reg.val[0] = data.reg.val[0];
+    reg.val[1] = data.reg.val[1];
+  }
+
+  explicit FP32Vec8(const BF16Vec8& v) {
+    reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
+    reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
+  }
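+
+  // Illustration (not from this patch): the merges above widen bf16 to fp32
+  // by pairing each 16-bit payload with a zero half-word, since a bf16 value
+  // is exactly the upper half of an IEEE-754 fp32. The scalar equivalent is:
+  //
+  //   uint16_t bf16 = ...;
+  //   uint32_t bits = static_cast<uint32_t>(bf16) << 16;
+  //   float f;
+  //   std::memcpy(&f, &bits, sizeof(f));  // needs <cstring>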
+
+  float reduce_sum() const {
+    AliasReg ar;
+    ar.reg = reg;
+    float result = 0;
+    unroll_loop<int, VEC_ELEM_NUM>(
+        [&result, &ar](int i) { result += ar.values[i]; });
+
+    return result;
+  }
+
+  FP32Vec8 exp() const {
+    // TODO: Vectorize this
+    AliasReg ar;
+    ar.reg = reg;
+    f32x4x4_t ret;
+    ret.val[0][0] = std::exp(ar.values[0]);
+    ret.val[0][1] = std::exp(ar.values[1]);
+    ret.val[0][2] = std::exp(ar.values[2]);
+    ret.val[0][3] = std::exp(ar.values[3]);
+    ret.val[1][0] = std::exp(ar.values[4]);
+    ret.val[1][1] = std::exp(ar.values[5]);
+    ret.val[1][2] = std::exp(ar.values[6]);
+    ret.val[1][3] = std::exp(ar.values[7]);
+    return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+  }
+
+  FP32Vec8 tanh() const {
+    // TODO: Vectorize this
+    AliasReg ar;
+    ar.reg = reg;
+    f32x4x4_t ret;
+    ret.val[0][0] = std::tanh(ar.values[0]);
+    ret.val[0][1] = std::tanh(ar.values[1]);
+    ret.val[0][2] = std::tanh(ar.values[2]);
+    ret.val[0][3] = std::tanh(ar.values[3]);
+    ret.val[1][0] = std::tanh(ar.values[4]);
+    ret.val[1][1] = std::tanh(ar.values[5]);
+    ret.val[1][2] = std::tanh(ar.values[6]);
+    ret.val[1][3] = std::tanh(ar.values[7]);
+    return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+  }
+
+  FP32Vec8 er() const {
+    // TODO: Vectorize this
+    AliasReg ar;
+    ar.reg = reg;
+    f32x4x4_t ret;
+    ret.val[0][0] = std::erf(ar.values[0]);
+    ret.val[0][1] = std::erf(ar.values[1]);
+    ret.val[0][2] = std::erf(ar.values[2]);
+    ret.val[0][3] = std::erf(ar.values[3]);
+    ret.val[1][0] = std::erf(ar.values[4]);
+    ret.val[1][1] = std::erf(ar.values[5]);
+    ret.val[1][2] = std::erf(ar.values[6]);
+    ret.val[1][3] = std::erf(ar.values[7]);
+    return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+  }
+
+  FP32Vec8 operator*(const FP32Vec8& b) const {
+    return FP32Vec8(
+        {vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
+  }
+
+  FP32Vec8 operator+(const FP32Vec8& b) const {
+    return FP32Vec8(
+        {vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
+  }
+
+  FP32Vec8 operator-(const FP32Vec8& b) const {
+    return FP32Vec8(
+        {vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
+  }
+
+  FP32Vec8 operator/(const FP32Vec8& b) const {
+    return FP32Vec8(
+        {vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
+  }
+
+  void save(float* ptr) const {
+    vec_xst(reg.val[0], 0, ptr);
+    vec_xst(reg.val[1], 16, ptr);
+  }
+};
+
+struct FP32Vec16 : public Vec<FP32Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  union AliasReg {
+    f32x4x4_t reg;
+    float values[VEC_ELEM_NUM];
+  };
+
+  f32x4x4_t reg;
+
+  explicit FP32Vec16(float v) {
+    reg.val[0] = vec_splats(v);
+    reg.val[1] = vec_splats(v);
+    reg.val[2] = vec_splats(v);
+    reg.val[3] = vec_splats(v);
+  }
+
+  explicit FP32Vec16() {
+    reg.val[0] = vec_splats(0.0f);
+    reg.val[1] = vec_splats(0.0f);
+    reg.val[2] = vec_splats(0.0f);
+    reg.val[3] = vec_splats(0.0f);
+  }
+
+  explicit FP32Vec16(const float* ptr) {
+    reg.val[0] = vec_xl(0, ptr);
+    reg.val[1] = vec_xl(16, ptr);
+    reg.val[2] = vec_xl(32, ptr);
+    reg.val[3] = vec_xl(48, ptr);
+  }
+
+  explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
+
+  explicit FP32Vec16(const FP32Vec16& data) {
+    reg.val[0] = data.reg.val[0];
+    reg.val[1] = data.reg.val[1];
+    reg.val[2] = data.reg.val[2];
+    reg.val[3] = data.reg.val[3];
+  }
+
+  explicit FP32Vec16(const FP32Vec4& data) {
+    reg.val[0] = data.reg;
+    reg.val[1] = data.reg;
+    reg.val[2] = data.reg;
+    reg.val[3] = data.reg;
+  }
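+
+  // Illustration (not from this patch): typical use of these wrappers for an
+  // elementwise product followed by a horizontal sum, assuming `a` and `b`
+  // each point to 16 valid floats:
+  //
+  //   vec_op::FP32Vec16 va(a), vb(b);
+  //   float dot = (va * vb).reduce_sum();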
+
+  explicit FP32Vec16(const FP32Vec8& data) {
+    reg.val[0] = data.reg.val[0];
+    reg.val[1] = data.reg.val[1];
+    reg.val[2] = data.reg.val[0];
+    reg.val[3] = data.reg.val[1];
+  }
+
+  explicit FP32Vec16(const BF16Vec16& v) {
+    reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
+    reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
+    reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
+    reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
+  }
+
+  explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
+
+  FP32Vec16 operator*(const FP32Vec16& b) const {
+    return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
+                                vec_mul(reg.val[1], b.reg.val[1]),
+                                vec_mul(reg.val[2], b.reg.val[2]),
+                                vec_mul(reg.val[3], b.reg.val[3])}));
+  }
+
+  FP32Vec16 operator+(const FP32Vec16& b) const {
+    return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
+                                vec_add(reg.val[1], b.reg.val[1]),
+                                vec_add(reg.val[2], b.reg.val[2]),
+                                vec_add(reg.val[3], b.reg.val[3])}));
+  }
+
+  FP32Vec16 operator-(const FP32Vec16& b) const {
+    return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
+                                vec_sub(reg.val[1], b.reg.val[1]),
+                                vec_sub(reg.val[2], b.reg.val[2]),
+                                vec_sub(reg.val[3], b.reg.val[3])}));
+  }
+
+  FP32Vec16 operator/(const FP32Vec16& b) const {
+    return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
+                                vec_div(reg.val[1], b.reg.val[1]),
+                                vec_div(reg.val[2], b.reg.val[2]),
+                                vec_div(reg.val[3], b.reg.val[3])}));
+  }
+
+  float reduce_sum() const {
+    AliasReg ar;
+    ar.reg = reg;
+    float result = 0;
+    unroll_loop<int, VEC_ELEM_NUM>(
+        [&result, &ar](int i) { result += ar.values[i]; });
+
+    return result;
+  }
+
+  template <int group_size>
+  float reduce_sub_sum(int idx) {
+    static_assert(VEC_ELEM_NUM % group_size == 0);
+
+    AliasReg ar;
+    ar.reg = reg;
+    float result = 0;
+    const int start = idx * group_size;
+    unroll_loop<int, group_size>(
+        [&result, &start, ar](int i) { result += ar.values[start + i]; });
+
+    return result;
+  }
+
+  void save(float* ptr) const {
+    vec_xst(reg.val[0], 0, ptr);
+    vec_xst(reg.val[1], 16, ptr);
+    vec_xst(reg.val[2], 32, ptr);
+    vec_xst(reg.val[3], 48, ptr);
+  }
+};
+
+template <typename T>
+struct VecType {
+  using vec_type = void;
+};
+
+template <typename T>
+using vec_t = typename VecType<T>::vec_type;
+
+template <>
+struct VecType<float> {
+  using vec_type = FP32Vec8;
+};
+
+template <>
+struct VecType<c10::BFloat16> {
+  using vec_type = BF16Vec8;
+};
+
+template <typename T>
+void storeFP32(float v, T* ptr) {
+  *ptr = v;
+}
+
+inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
+  acc = acc + a * b;
+}
+
+namespace c10 {
+struct BFloat16 {
+  uint16_t value;  // Assume BFloat16 is defined as a struct containing a
+                   // 16-bit value.
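+  // Illustration (not from this patch): the fp32 -> bf16 path in this header
+  // truncates, i.e. it keeps the upper 16 bits of the fp32 representation.
+  // Scalar sketch:
+  //
+  //   float f = ...;
+  //   uint32_t bits;
+  //   std::memcpy(&bits, &f, sizeof(bits));  // needs <cstring>
+  //   uint16_t bf16 = static_cast<uint16_t>(bits >> 16);
+  //
+  // The vector converters further below also canonicalize NaN inputs to the
+  // quiet-NaN pattern 0x7fc00000 before truncating, so a NaN stays a NaN
+  // after the shift.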
+};
+}  // namespace c10
+
+template <>
+inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
+  c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
+      reinterpret_cast<c10::BFloat16*>(&v);
+  *ptr = *(v_ptr + 1);
+}
+
+#ifndef __VEC_CLASS_FP_NAN
+  #define __VEC_CLASS_FP_NAN (1 << 6)
+#endif
+
+const static __vector unsigned char omask = {2,  3,  6,  7,  10, 11, 14, 15,
+                                             18, 19, 22, 23, 26, 27, 30, 31};
+const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
+                                           0x00007fff};
+const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
+                                          0x7fc00000};
+const static __vector unsigned int sh16 = {16, 16, 16, 16};
+const static __vector unsigned int one = {1, 1, 1, 1};
+
+inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
+  __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
+  __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
+  int cc;
+  __vector __bool int sel0 =
+      vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
+  __vector __bool int sel1 =
+      vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
+  inp0 = vec_sel(inp0, nan, sel0) >> sh16;
+  inp1 = vec_sel(inp1, nan, sel1) >> sh16;
+  reg = (__vector signed short)vec_perm(inp0, inp1, omask);
+}
+
+inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
+  __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
+  __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
+  __vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
+  __vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
+  int cc;
+  __vector __bool int sel0 =
+      vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
+  __vector __bool int sel1 =
+      vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
+  __vector __bool int sel2 =
+      vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
+  __vector __bool int sel3 =
+      vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
+  inp0 = vec_sel(inp0, nan, sel0) >> sh16;
+  inp1 = vec_sel(inp1, nan, sel1) >> sh16;
+  inp2 = vec_sel(inp2, nan, sel2) >> sh16;
+  inp3 = vec_sel(inp3, nan, sel3) >> sh16;
+  reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
+  reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
+}
+
+inline void prefetch(const void* addr) { __builtin_prefetch(addr); }
+
+};  // namespace vec_op
+
+#endif
\ No newline at end of file
diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
index 33b16378328..6751e7e55fc 100644
--- a/csrc/cpu/quant.cpp
+++ b/csrc/cpu/quant.cpp
@@ -25,7 +25,7 @@ struct KernelVecType {
 
 template <>
 struct KernelVecType<float> {
-#ifdef __powerpc64__
+#if defined(__powerpc64__) || defined(__s390x__)
   // Power architecture-specific vector type
   using load_vec_type = vec_op::FP32Vec16;
 #else

From 561eca879cb077c3e0bdf1a9752af7bf42b81be3 Mon Sep 17 00:00:00 2001
From: Rishika Kedia
Date: Fri, 31 Jan 2025 17:15:28 +0530
Subject: [PATCH 02/13] dockerising and requirements-cpu for vllm

Signed-off-by: Rishika Kedia
---
 Dockerfile.s390x     | 91 ++++++++++++++++++++++++++++++++++++++++++++
 requirements-cpu.txt | 12 +++---
 2 files changed, 97 insertions(+), 6 deletions(-)
 create mode 100644 Dockerfile.s390x

diff --git a/Dockerfile.s390x b/Dockerfile.s390x
new file mode 100644
index 00000000000..c182f6ec927
--- /dev/null
+++ b/Dockerfile.s390x
@@ -0,0 +1,91 @@
+# Dockerfile for vLLM on s390x architecture with dependencies, PyTorch, torchvision, and Apache Arrow
+
+FROM ubuntu:22.04
+USER root
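+
+# Illustrative build/run commands (not part of the original file; the image
+# tag is a placeholder and the model is just a small example):
+#   docker build -f Dockerfile.s390x -t vllm-cpu-s390x .
+#   docker run --rm -p 8000:8000 vllm-cpu-s390x --model facebook/opt-125m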
+
+
+RUN apt-get update -y && apt-get install -y \
+    git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev protobuf-compiler \
+    build-essential ffmpeg libsm6 libxext6 libgl1 python3 python3-pip cmake ninja-build \
+    cargo libjpeg-dev libpng-dev zlib1g-dev libavcodec-dev libavformat-dev libswscale-dev \
+    libtiff-dev libwebp-dev llvm-dev libclang-dev clang libssl-dev g++ \
+    python3-distutils python3-setuptools libbz2-dev liblz4-dev libzstd-dev \
+    libsnappy-dev rapidjson-dev libboost-dev liborc-dev pkg-config libopenblas-dev
+
+
+RUN python3 -m pip install --upgrade pip setuptools
+
+
+RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
+    . "$HOME/.cargo/env"
+
+
+WORKDIR /tmp/arrow
+RUN git clone https://github.com/apache/arrow.git && \
+    cd arrow/cpp && \
+    mkdir release && cd release && \
+    cmake -DCMAKE_BUILD_TYPE=Release \
+        -DCMAKE_INSTALL_PREFIX=/usr/local \
+        -DARROW_PYTHON=ON \
+        -DARROW_PARQUET=ON \
+        -DARROW_ORC=ON \
+        -DARROW_FILESYSTEM=ON \
+        -DARROW_WITH_LZ4=ON \
+        -DARROW_WITH_ZSTD=ON \
+        -DARROW_WITH_SNAPPY=ON \
+        -DARROW_JSON=ON \
+        -DARROW_CSV=ON \
+        -DPROTOBUF_PROTOC_EXECUTABLE=/usr/bin/protoc \
+        -DARROW_DEPENDENCY_SOURCE=BUNDLED \
+        .. && \
+    make -j$(nproc) && make install
+
+
+ENV CMAKE_PREFIX_PATH=/usr/local/lib/cmake:$CMAKE_PREFIX_PATH
+ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
+
+
+RUN pip install --no-cache-dir https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241201%2Bcpu-cp310-cp310-linux_s390x.whl
+
+
+RUN git clone https://github.com/pytorch/vision.git && \
+    cd vision && \
+    git checkout 48b1edf && \
+    python3 setup.py bdist_wheel && \
+    pip install --no-cache-dir dist/*.whl && \
+    rm -rf dist
+
+
+COPY . /workspace/vllm
+WORKDIR /workspace/vllm
+
+
+RUN git config --global --add safe.directory /workspace/vllm && \
+    git rev-parse --is-inside-work-tree || git init && \
+    git config --global user.email "docker@build.local" && \
+    git config --global user.name "Docker Builder" && \
+    git add . && git commit -m "Initial commit" --allow-empty
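+
+# Note (an assumption about why the step above exists): vLLM derives its
+# version with setuptools-scm, which needs a git history to inspect, so the
+# empty commit satisfies it when the build context was copied in without .git.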
+
+
+ARG GIT_REPO_CHECK=0
+RUN --mount=type=bind,source=.git,target=.git \
+    if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
+
+
+RUN mkdir -p /root/.cache/pip && chmod -R 777 /root/.cache/pip
+
+RUN pip install --no-cache-dir -v \
+    'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
+    -r requirements-cpu.txt \
+    xformers uvloop==0.21.0
+
+RUN python3 -m pip install --upgrade setuptools
+
+RUN --mount=type=cache,target=/root/.cache/pip \
+    --mount=type=cache,target=/root/.cache/ccache \
+    VLLM_TARGET_DEVICE=cpu python3 setup.py build_ext --inplace && \
+    VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \
+    pip install --no-cache-dir dist/*.whl && \
+    rm -rf dist
+
+ENTRYPOINT ["/usr/bin/python3", "-m", "vllm.entrypoints.openai.api_server"]
\ No newline at end of file
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
index ecfa822e011..7ca3a5ed210 100644
--- a/requirements-cpu.txt
+++ b/requirements-cpu.txt
@@ -2,14 +2,14 @@
 -r requirements-common.txt
 
 # Dependencies for CPUs
-torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin"
-torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin"
+torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_machine != "s390x" and platform_system != "Darwin"
+torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_machine != "s390x" or platform_system == "Darwin"
 
 # required for the image processor of minicpm-o-2_6, this must be updated alongside torch
-torchaudio; platform_machine != "ppc64le"
-torchaudio==2.5.1; platform_machine == "ppc64le"
+torchaudio; platform_machine != "ppc64le" or platform_machine != "s390x"
+torchaudio==2.5.1; platform_machine == "ppc64le" or platform_machine != "s390x"
 
 # required for the image processor of phi3v, this must be updated alongside torch
-torchvision; platform_machine != "ppc64le"
-torchvision==0.20.1; platform_machine == "ppc64le"
+torchvision; platform_machine != "ppc64le" or platform_machine != "s390x"
+torchvision==0.20.1; platform_machine == "ppc64le" or platform_machine != "s390x"
 
 datasets # for benchmark scripts

From 0fc4f37949c112c592e43601e3b926414ada6f42 Mon Sep 17 00:00:00 2001
From: Dilip Gowda Bhagavan
Date: Fri, 31 Jan 2025 18:31:19 +0530
Subject: [PATCH 03/13] corrected extension for cpu_types_vxe.cpp to
 cpu_types_vxe.hpp

Signed-off-by: Dilip Gowda Bhagavan
---
 csrc/cpu/{cpu_types_vxe.cpp => cpu_types_vxe.hpp} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename csrc/cpu/{cpu_types_vxe.cpp => cpu_types_vxe.hpp} (100%)

diff --git a/csrc/cpu/cpu_types_vxe.cpp b/csrc/cpu/cpu_types_vxe.hpp
similarity index 100%
rename from csrc/cpu/cpu_types_vxe.cpp
rename to csrc/cpu/cpu_types_vxe.hpp

From edd02fa285859c838451b14db7bb69c8e5006517 Mon Sep 17 00:00:00 2001
From: Dilip Gowda Bhagavan
Date: Wed, 5 Feb 2025 18:15:58 +0530
Subject: [PATCH 04/13] addressing review comments

Signed-off-by: Dilip Gowda Bhagavan
---
 Dockerfile.s390x          | 44 +++++++++++++--------------------------
 cmake/cpu_extension.cmake |  2 +-
 requirements-cpu.txt      | 13 ++++++------
 3 files changed, 22 insertions(+), 37 deletions(-)

diff --git a/Dockerfile.s390x b/Dockerfile.s390x
index c182f6ec927..75c422e46cc 100644
--- a/Dockerfile.s390x
+++ b/Dockerfile.s390x
@@ -1,9 +1,7 @@
 # Dockerfile for vLLM on s390x architecture with dependencies, PyTorch, torchvision, and Apache Arrow
-
 FROM ubuntu:22.04
 USER root
-
 RUN apt-get update -y && apt-get install -y \
     git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev protobuf-compiler \
     build-essential ffmpeg libsm6 libxext6 libgl1 python3 python3-pip cmake ninja-build \
@@ -12,14 +10,11 @@ RUN apt-get update -y && apt-get install -y \
     python3-distutils python3-setuptools libbz2-dev liblz4-dev libzstd-dev \
     libsnappy-dev rapidjson-dev libboost-dev liborc-dev pkg-config libopenblas-dev
 
-
 RUN python3 -m pip install --upgrade pip setuptools
 
-
 RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
     . "$HOME/.cargo/env"
 
-
 WORKDIR /tmp/arrow
 RUN git clone https://github.com/apache/arrow.git && \
     cd arrow/cpp && \
@@ -37,47 +32,36 @@ RUN git clone https://github.com/apache/arrow.git && \
     -DARROW_CSV=ON \
     -DPROTOBUF_PROTOC_EXECUTABLE=/usr/bin/protoc \
     -DARROW_DEPENDENCY_SOURCE=BUNDLED \
+    -DARROW_DATASET=libarrow_dataset.so \
     .. && \
     make -j$(nproc) && make install
 
-
 ENV CMAKE_PREFIX_PATH=/usr/local/lib/cmake:$CMAKE_PREFIX_PATH
 ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
 
-
-RUN pip install --no-cache-dir https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241201%2Bcpu-cp310-cp310-linux_s390x.whl
-
-
-RUN git clone https://github.com/pytorch/vision.git && \
-    cd vision && \
-    git checkout 48b1edf && \
-    python3 setup.py bdist_wheel && \
-    pip install --no-cache-dir dist/*.whl && \
-    rm -rf dist
-
-
 COPY . /workspace/vllm
 WORKDIR /workspace/vllm
 
-
-RUN git config --global --add safe.directory /workspace/vllm && \
-    git rev-parse --is-inside-work-tree || git init && \
-    git config --global user.email "docker@build.local" && \
-    git config --global user.name "Docker Builder" && \
-    git add . && git commit -m "Initial commit" --allow-empty
-
-
 ARG GIT_REPO_CHECK=0
 RUN --mount=type=bind,source=.git,target=.git \
     if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
 
-
 RUN mkdir -p /root/.cache/pip && chmod -R 777 /root/.cache/pip
 
-RUN pip install --no-cache-dir -v \
+RUN --mount=type=cache,target=/root/.cache/pip \
+    --mount=type=cache,target=/root/.cache/ccache \
+    pip install --no-cache-dir -v \
     'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-    -r requirements-cpu.txt \
-    xformers uvloop==0.21.0
+    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
+    -r requirements-cpu.txt
+
+ARG TORCH_VISION_VERSION=v0.20.1
+RUN git clone https://github.com/pytorch/vision.git && \
+    cd vision && \
+    git checkout $TORCH_VISION_VERSION && \
+    python3 setup.py bdist_wheel && \
+    pip install --no-cache-dir dist/*.whl && \
+    rm -rf dist
 
 RUN python3 -m pip install --upgrade setuptools
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index 14186fd66be..ca2ffb1bc3c 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -130,7 +130,7 @@ elseif (ASIMD_FOUND)
 elseif(APPLE_SILICON_FOUND)
     message(STATUS "Apple Silicon Detected")
     set(ENABLE_NUMA OFF)
- elseif (S390_FOUND)
+elseif (S390_FOUND)
     message(STATUS "S390 detected")
     # Check for S390 VXE support
     list(APPEND CXX_COMPILE_FLAGS
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
index 7ca3a5ed210..3f633a0ace0 100644
--- a/requirements-cpu.txt
+++ b/requirements-cpu.txt
@@ -2,14 +2,15 @@
 -r requirements-common.txt
 
 # Dependencies for CPUs
-torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_machine != "s390x" and platform_system != "Darwin"
-torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_machine != "s390x" or platform_system == "Darwin"
+torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin" and platform_machine != "s390x"
+torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin"
+torch==2.6.0.dev20241212+cpu; platform_machine == "s390x"
 
 # required for the image processor of minicpm-o-2_6, this must be updated alongside torch
-torchaudio; platform_machine != "ppc64le" or platform_machine != "s390x"
-torchaudio==2.5.1; platform_machine == "ppc64le" or platform_machine != "s390x"
+torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x"
+torchaudio==2.5.1; platform_machine == "ppc64le"
 
 # required for the image processor of phi3v, this must be updated alongside torch
-torchvision; platform_machine != "ppc64le" or platform_machine != "s390x"
-torchvision==0.20.1; platform_machine == "ppc64le" or platform_machine != "s390x"
+torchvision; platform_machine != "ppc64le" and platform_machine != "s390x"
+torchvision==0.20.1; platform_machine == "ppc64le"
 
 datasets # for benchmark scripts

From c22c9da29e977cf901f4518d027cce76d757628d Mon Sep 17 00:00:00 2001
From: Dilip Gowda Bhagavan
Date: Tue, 11 Feb 2025 14:43:01 +0530
Subject: [PATCH 05/13] change base image to ubi-minimal and incorporate review
 comments

Signed-off-by: Dilip Gowda Bhagavan
---
 Dockerfile.s390x     | 131 ++++++++++++++++++++++++++++++++++---------
 requirements-cpu.txt |   2 +-
 2 files changed, 105 insertions(+), 28 deletions(-)

diff --git a/Dockerfile.s390x b/Dockerfile.s390x
index 75c422e46cc..6f8abdac672 100644
--- a/Dockerfile.s390x
+++ b/Dockerfile.s390x
@@ -1,22 +1,45 @@
-# Dockerfile for vLLM on s390x architecture with dependencies, PyTorch, torchvision, and Apache Arrow
-FROM ubuntu:22.04
-USER root
+# Base UBI image for s390x architecture
+ARG BASE_UBI_IMAGE_TAG=9.5-1736404155
+FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} as base
 
-RUN apt-get update -y && apt-get install -y \
-    git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev protobuf-compiler \
-    build-essential ffmpeg libsm6 libxext6 libgl1 python3 python3-pip cmake ninja-build \
-    cargo libjpeg-dev libpng-dev zlib1g-dev libavcodec-dev libavformat-dev libswscale-dev \
-    libtiff-dev libwebp-dev llvm-dev libclang-dev clang libssl-dev g++ \
-    python3-distutils python3-setuptools libbz2-dev liblz4-dev libzstd-dev \
-    libsnappy-dev rapidjson-dev libboost-dev liborc-dev pkg-config libopenblas-dev
+# Install basic dependencies
+RUN microdnf -y update && microdnf install -y \
+    python-pip python-wheel \
+    && microdnf clean all
 
-RUN python3 -m pip install --upgrade pip setuptools
+WORKDIR /workspace
 
+ENV LANG=C.UTF-8 \
+    LC_ALL=C.UTF-8
+
+# Install development utilities
+RUN microdnf install -y \
+    which procps findutils tar vim git gcc g++ make patch make cython cargo zlib-devel \
+    libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \
+    openssl-devel openblas openblas-devel wget autoconf automake libtool && \
+    microdnf clean all
+
+# Python Installation
+FROM base as python-install
+
+RUN microdnf install -y python-devel && microdnf clean all
+
+# Set up Python virtual environment
+RUN python -m venv /opt/venv/vllm
+ENV PATH="/opt/venv/vllm/bin:$PATH"
+
+# Upgrade pip and install base tools
+RUN python -m pip install --no-cache -U pip wheel uv cmake setuptools
+
+# Install Rust
+RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
+    . "$HOME/.cargo/env"
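+
+# (Editorial note, an assumption: Rust is kept because several Python
+# dependencies, e.g. the tokenizers package, are Rust projects with no
+# prebuilt s390x wheels and must be compiled from source during pip install.)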
"$HOME/.cargo/env" -WORKDIR /tmp/arrow -RUN git clone https://github.com/apache/arrow.git && \ +# Build Apache Arrow +workdir /tmp +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/ccache \ + git clone https://github.com/apache/arrow.git && \ cd arrow/cpp && \ mkdir release && cd release && \ cmake -DCMAKE_BUILD_TYPE=Release \ @@ -30,46 +53,100 @@ RUN git clone https://github.com/apache/arrow.git && \ -DARROW_WITH_SNAPPY=ON \ -DARROW_JSON=ON \ -DARROW_CSV=ON \ + -DARROW_DATASET=ON \ -DPROTOBUF_PROTOC_EXECUTABLE=/usr/bin/protoc \ -DARROW_DEPENDENCY_SOURCE=BUNDLED \ - -DARROW_DATASET=libarrow_dataset.so \ .. && \ - make -j$(nproc) && make install + make -j$(nproc) && \ + make install && \ + cd ../../python && \ + export PYARROW_PARALLEL=4 && \ + export ARROW_BUILD_TYPE=release && \ + python -m pip install -r requirements-build.txt && \ + python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel + -ENV CMAKE_PREFIX_PATH=/usr/local/lib/cmake:$CMAKE_PREFIX_PATH -ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH +# Install numactl (needed for numa.h dependency) +WORKDIR /tmp +RUN curl -LO https://github.com/numactl/numactl/archive/refs/tags/v2.0.16.tar.gz && \ + tar -xvzf v2.0.16.tar.gz && \ + cd numactl-2.0.16 && \ + ./autogen.sh && \ + ./configure && \ + make && \ + make install +# Clean up build files +RUN rm -rf /tmp/numactl-2.0.16 /tmp/v2.0.16.tar.gz + +# Set include path +ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH" + +# Copy vLLM source COPY . /workspace/vllm WORKDIR /workspace/vllm +# Check git repository integrity if enabled ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi -RUN mkdir -p /root/.cache/pip && chmod -R 777 /root/.cache/pip - +# Install dependencies, including PyTorch and Apache Arrow RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/ccache \ pip install --no-cache-dir -v \ 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ + /tmp/arrow/python/dist/*.whl \ --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ -r requirements-cpu.txt +#Clean up build files for arrow +RUN rm -rf /tmp/arrow + +# Install torchvision ARG TORCH_VISION_VERSION=v0.20.1 +WORKDIR /tmp RUN git clone https://github.com/pytorch/vision.git && \ cd vision && \ git checkout $TORCH_VISION_VERSION && \ - python3 setup.py bdist_wheel && \ - pip install --no-cache-dir dist/*.whl && \ - rm -rf dist + python setup.py bdist_wheel && \ + pip install --no-cache-dir dist/*.whl + +#Clean up build files for vision +RUN rm -rf /tmp/vision + +# Final build stage +FROM python-install as vllm-cpu + +# Ensure we are using the virtual environment +ENV PATH="/opt/venv/vllm/bin:$PATH" + +# Set correct library path for torch and numactl +ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python3.9/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH" -RUN python3 -m pip install --upgrade setuptools +# Upgrade setuptools for compatibility +RUN python -m pip install --upgrade setuptools +WORKDIR /workspace/vllm + +# Build and install vllm RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/ccache \ - VLLM_TARGET_DEVICE=cpu python3 setup.py build_ext --inplace && \ - VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ - pip install --no-cache-dir dist/*.whl && \ + VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \ + pip install dist/*.whl && \ rm -rf dist -ENTRYPOINT 
+
+# Copy vLLM source
+COPY . /workspace/vllm
+WORKDIR /workspace/vllm
+
+# Check git repository integrity if enabled
+ARG GIT_REPO_CHECK=0
+RUN --mount=type=bind,source=.git,target=.git \
+    if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
 
-RUN mkdir -p /root/.cache/pip && chmod -R 777 /root/.cache/pip
-
+# Install dependencies, including PyTorch and Apache Arrow
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=cache,target=/root/.cache/ccache \
     pip install --no-cache-dir -v \
     'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
+    /tmp/arrow/python/dist/*.whl \
     --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
     -r requirements-cpu.txt
 
+#Clean up build files for arrow
+RUN rm -rf /tmp/arrow
+
+# Install torchvision
 ARG TORCH_VISION_VERSION=v0.20.1
+WORKDIR /tmp
 RUN git clone https://github.com/pytorch/vision.git && \
     cd vision && \
     git checkout $TORCH_VISION_VERSION && \
-    python3 setup.py bdist_wheel && \
-    pip install --no-cache-dir dist/*.whl && \
-    rm -rf dist
+    python setup.py bdist_wheel && \
+    pip install --no-cache-dir dist/*.whl
+
+#Clean up build files for vision
+RUN rm -rf /tmp/vision
+
+# Final build stage
+FROM python-install as vllm-cpu
+
+# Ensure we are using the virtual environment
+ENV PATH="/opt/venv/vllm/bin:$PATH"
+
+# Set correct library path for torch and numactl
+ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python3.9/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH"
 
+# Upgrade setuptools for compatibility
 RUN python -m pip install --upgrade setuptools
 
+WORKDIR /workspace/vllm
+
+# Build and install vllm
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=cache,target=/root/.cache/ccache \
-    VLLM_TARGET_DEVICE=cpu python3 setup.py build_ext --inplace && \
-    VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \
-    pip install --no-cache-dir dist/*.whl && \
+    VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \
+    pip install --no-cache-dir dist/*.whl && \
     rm -rf dist
 
-ENTRYPOINT ["/usr/bin/python3", "-m", "vllm.entrypoints.openai.api_server"]
\ No newline at end of file
+# setup non-root user for vllm
+RUN umask 002 && \
+    useradd --uid 2000 --gid 0 vllm && \
+    mkdir -p /home/vllm && \
+    chmod g+rwx /home/vllm /usr/src /workspace
+
+COPY LICENSE /licenses/vllm.md
+COPY examples/*.jinja /app/data/template/
+
+USER 2000
+WORKDIR /home/vllm
+
+# Set the default entrypoint
+ENTRYPOINT ["/opt/venv/vllm/bin/python", "-m", "vllm.entrypoints.openai.api_server"]
\ No newline at end of file
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
index 3f633a0ace0..930610401e7 100644
--- a/requirements-cpu.txt
+++ b/requirements-cpu.txt
@@ -4,7 +4,7 @@
 # Dependencies for CPUs
 torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin" and platform_machine != "s390x"
 torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin"
-torch==2.6.0.dev20241212+cpu; platform_machine == "s390x"
+torch==2.6.0.dev20250104+cpu; platform_machine == "s390x"
 
 # required for the image processor of minicpm-o-2_6, this must be updated alongside torch
 torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x"

From 222b54c53174ea889abdad563d36346ae3d34930 Mon Sep 17 00:00:00 2001
From: Dilip Gowda Bhagavan
Date: Tue, 11 Feb 2025 21:18:11 +0530
Subject: [PATCH 06/13] incorporate review comments

Signed-off-by: Dilip Gowda Bhagavan
---
 Dockerfile.s390x | 59 ++++++++++++++++++++++++------------------------
 1 file changed, 29 insertions(+), 30 deletions(-)

diff --git a/Dockerfile.s390x b/Dockerfile.s390x
index 6f8abdac672..df80a65bd5f 100644
--- a/Dockerfile.s390x
+++ b/Dockerfile.s390x
@@ -1,10 +1,13 @@
 # Base UBI image for s390x architecture
 ARG BASE_UBI_IMAGE_TAG=9.5-1736404155
+ARG PYTHON_VERSION=3.12
 FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} as base
 
 # Install basic dependencies
+ARG PYTHON_VERSION
+ENV PYTHON_VERSION=${PYTHON_VERSION}
 RUN microdnf -y update && microdnf install -y \
-    python-pip python-wheel \
+    python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel \
     && microdnf clean all
 
 WORKDIR /workspace
@@ -21,24 +24,28 @@ RUN microdnf install -y \
 
 # Python Installation
 FROM base as python-install
+ARG PYTHON_VERSION
 
-RUN microdnf install -y python-devel && microdnf clean all
-
-# Set up Python virtual environment
-RUN python -m venv /opt/venv/vllm
-ENV PATH="/opt/venv/vllm/bin:$PATH"
+ENV VIRTUAL_ENV=/opt/venv/vllm
+ENV PATH="$VIRTUAL_ENV/bin:$PATH"
+ENV PYTHON_VERSION=${PYTHON_VERSION}
+RUN microdnf install -y \
+    python${PYTHON_VERSION}-devel && \
+    python${PYTHON_VERSION} -m venv $VIRTUAL_ENV && pip install --no-cache -U pip wheel uv && microdnf clean all
 
 # Upgrade pip and install base tools
-RUN python -m pip install --no-cache -U pip wheel uv cmake setuptools
+RUN python -m pip install -U pip wheel uv cmake setuptools
 
 # Install Rust
 RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
     . "$HOME/.cargo/env"
 
+FROM python-install as pyarrow
+
 # Build Apache Arrow
-workdir /tmp
+WORKDIR /tmp
 RUN --mount=type=cache,target=/root/.cache/pip \
-    --mount=type=cache,target=/root/.cache/ccache \
+    --mount=type=cache,target=/root/.cache/uv \
     git clone https://github.com/apache/arrow.git && \
     cd arrow/cpp && \
     mkdir release && cd release && \
@@ -76,12 +83,10 @@ RUN curl -LO https://github.com/numactl/numactl/archive/refs/tags/v2.0.16.tar.gz
     make && \
     make install
 
-# Clean up build files
-RUN rm -rf /tmp/numactl-2.0.16 /tmp/v2.0.16.tar.gz
-
 # Set include path
 ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH"
 
+FROM pyarrow as python-dependecies
 # Copy vLLM source
 COPY . /workspace/vllm
 WORKDIR /workspace/vllm
@@ -93,8 +98,8 @@ ARG GIT_REPO_CHECK=0
 RUN --mount=type=bind,source=.git,target=.git \
     if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
 
 # Install dependencies, including PyTorch and Apache Arrow
 RUN --mount=type=cache,target=/root/.cache/pip \
-    --mount=type=cache,target=/root/.cache/ccache \
-    pip install --no-cache-dir -v \
+    --mount=type=cache,target=/root/.cache/uv \
+    pip install -v \
     'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
     /tmp/arrow/python/dist/*.whl \
     --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
     -r requirements-cpu.txt
 
@@ -106,34 +111,28 @@ RUN rm -rf /tmp/arrow
 # Install torchvision
 ARG TORCH_VISION_VERSION=v0.20.1
 WORKDIR /tmp
-RUN git clone https://github.com/pytorch/vision.git && \
+RUN --mount=type=cache,target=/root/.cache/pip \
+    --mount=type=cache,target=/root/.cache/uv \
+    git clone https://github.com/pytorch/vision.git && \
     cd vision && \
     git checkout $TORCH_VISION_VERSION && \
     python setup.py bdist_wheel && \
-    pip install --no-cache-dir dist/*.whl
-
-#Clean up build files for vision
-RUN rm -rf /tmp/vision
+    uv pip install dist/*.whl
 
 # Final build stage
-FROM python-install as vllm-cpu
-
-# Ensure we are using the virtual environment
-ENV PATH="/opt/venv/vllm/bin:$PATH"
+FROM python-dependecies as vllm-cpu
+ARG PYTHON_VERSION
 
 # Set correct library path for torch and numactl
-ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python3.9/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH"
-
-# Upgrade setuptools for compatibility
-RUN python -m pip install --upgrade setuptools
+ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python${PYTHON_VERSION}/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH"
 
 WORKDIR /workspace/vllm
 
 # Build and install vllm
 RUN --mount=type=cache,target=/root/.cache/pip \
-    --mount=type=cache,target=/root/.cache/ccache \
+    --mount=type=cache,target=/root/.cache/uv \
     VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \
-    pip install dist/*.whl && \
+    uv pip install dist/*.whl && \
     rm -rf dist
 
 # setup non-root user for vllm
@@ -149,4 +148,4 @@ USER 2000
 WORKDIR /home/vllm
 
 # Set the default entrypoint
-ENTRYPOINT ["/opt/venv/vllm/bin/python", "-m", "vllm.entrypoints.openai.api_server"]
\ No newline at end of file
+ENTRYPOINT ["/opt/vllm/bin/python", "-m", "vllm.entrypoints.openai.api_server"]

From 02554db2ccee4503a9ab5fcc3b31f1f5f5409138 Mon Sep 17 00:00:00 2001
From: Dilip Gowda Bhagavan
Date: Thu, 13 Feb 2025 11:33:55 +0530
Subject: [PATCH 07/13] correct virtual env path

Signed-off-by: Dilip Gowda Bhagavan
---
 Dockerfile.s390x | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Dockerfile.s390x b/Dockerfile.s390x
index df80a65bd5f..d929bdd64ba 100644
--- a/Dockerfile.s390x
+++ b/Dockerfile.s390x
@@ -26,7 +26,7 @@ RUN microdnf install -y \
 FROM base as python-install
 ARG PYTHON_VERSION
 
-ENV VIRTUAL_ENV=/opt/venv/vllm
+ENV VIRTUAL_ENV=/opt/vllm
 ENV PATH="$VIRTUAL_ENV/bin:$PATH"
 ENV PYTHON_VERSION=${PYTHON_VERSION}
 RUN microdnf install -y \

From c57ceab3f6864828990a5d46e9cc65589220d1d3 Mon Sep 17 00:00:00 2001
From: Dilip Gowda Bhagavan
Date: Fri, 21 Feb 2025 15:18:04 +0530
Subject: [PATCH 08/13] incorporating review comments

Signed-off-by: Dilip Gowda Bhagavan
---
 Dockerfile.s390x | 41 +++++++++++++++++++++--------------------
 1 file changed, 21 insertions(+), 20 deletions(-)

diff --git a/Dockerfile.s390x b/Dockerfile.s390x
index d929bdd64ba..1afabd91168 100644
--- a/Dockerfile.s390x
+++ b/Dockerfile.s390x
@@ -1,7 +1,7 @@
 # Base UBI image for s390x architecture
 ARG BASE_UBI_IMAGE_TAG=9.5-1736404155
 ARG PYTHON_VERSION=3.12
-FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} as base
+FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base
 
 # Install basic dependencies
 ARG PYTHON_VERSION
@@ -17,13 +17,17 @@ ENV LANG=C.UTF-8 \
 
 # Install development utilities
 RUN microdnf install -y \
-    which procps findutils tar vim git gcc g++ make patch make cython cargo zlib-devel \
+    which procps findutils tar vim git gcc g++ make patch make cython zlib-devel \
     libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \
-    openssl-devel openblas openblas-devel wget autoconf automake libtool && \
+    openssl-devel openblas openblas-devel wget autoconf automake libtool cmake && \
     microdnf clean all
 
+# Install Rust
+RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
+    . "$HOME/.cargo/env"
+
 # Python Installation
-FROM base as python-install
+FROM base AS python-install
 ARG PYTHON_VERSION
 
 ENV VIRTUAL_ENV=/opt/vllm
@@ -36,11 +40,7 @@ RUN microdnf install -y \
 # Upgrade pip and install base tools
 RUN python -m pip install -U pip wheel uv cmake setuptools
 
-# Install Rust
-RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
-    . "$HOME/.cargo/env"
-
-FROM python-install as pyarrow
+FROM python-install AS pyarrow
 
 # Build Apache Arrow
 WORKDIR /tmp
@@ -70,8 +70,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     export PYARROW_PARALLEL=4 && \
     export ARROW_BUILD_TYPE=release && \
     python -m pip install -r requirements-build.txt && \
-    python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel
+    python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel && \
+    mkdir -p /tmp/arrow_wheels/ && \
+    cp dist/*.whl /tmp/arrow_wheels/
 
+FROM python-install AS python-dependencies
 # Install numactl (needed for numa.h dependency)
 WORKDIR /tmp
@@ -86,7 +89,6 @@ RUN curl -LO https://github.com/numactl/numactl/archive/refs/tags/v2.0.16.tar.gz
 # Set include path
 ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH"
 
-FROM pyarrow as python-dependecies
 # Copy vLLM source
 COPY . /workspace/vllm
 WORKDIR /workspace/vllm
@@ -96,18 +98,19 @@ ARG GIT_REPO_CHECK=0
 RUN --mount=type=bind,source=.git,target=.git \
     if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
 
+# Copy built PyArrow wheels from pyarrow stage
+COPY --from=pyarrow /tmp/arrow_wheels/*.whl /tmp/arrow_wheels/
+
+ENV PATH="/root/.cargo/bin:$PATH"
+
 # Install dependencies, including PyTorch and Apache Arrow
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=cache,target=/root/.cache/uv \
     pip install -v \
-    'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-    /tmp/arrow/python/dist/*.whl \
+    'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 /tmp/arrow_wheels/*.whl \
     --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
     -r requirements-cpu.txt
 
-#Clean up build files for arrow
-RUN rm -rf /tmp/arrow
-
 # Install torchvision
 ARG TORCH_VISION_VERSION=v0.20.1
 WORKDIR /tmp
@@ -120,20 +123,18 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     uv pip install dist/*.whl
 
 # Final build stage
-FROM python-dependecies as vllm-cpu
+FROM python-dependencies AS vllm-cpu
 ARG PYTHON_VERSION
 
 # Set correct library path for torch and numactl
 ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python${PYTHON_VERSION}/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH"
-
 WORKDIR /workspace/vllm
 
 # Build and install vllm
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=cache,target=/root/.cache/uv \
     VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \
-    uv pip install dist/*.whl && \
-    rm -rf dist
+    uv pip install "$(echo dist/*.whl)[tensorizer]"
 
 # setup non-root user for vllm

From 2dcb4a9bee49788b9df2a8315816462245d7629c Mon Sep 17 00:00:00 2001
From: Dilip Gowda Bhagavan
Date: Wed, 26 Feb 2025 20:56:02 +0530
Subject: [PATCH 09/13] refactor: optimize the Docker layers

Signed-off-by: Dilip Gowda Bhagavan
---
 Dockerfile.s390x | 46 +++++++++++++++++++++++++++++++---------------
 1 file changed, 31 insertions(+), 15 deletions(-)

diff --git a/Dockerfile.s390x b/Dockerfile.s390x
index 1afabd91168..1a8d313619c 100644
--- a/Dockerfile.s390x
+++ b/Dockerfile.s390x
@@ -22,10 +22,6 @@ RUN microdnf install -y \
     openssl-devel openblas openblas-devel wget autoconf automake libtool cmake && \
     microdnf clean all
 
-# Install Rust
-RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
-    . "$HOME/.cargo/env"
-
 # Python Installation
 FROM base AS python-install
 ARG PYTHON_VERSION
@@ -70,12 +66,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     export PYARROW_PARALLEL=4 && \
     export ARROW_BUILD_TYPE=release && \
     python -m pip install -r requirements-build.txt && \
-    python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel && \
-    mkdir -p /tmp/arrow_wheels/ && \
-    cp dist/*.whl /tmp/arrow_wheels/
+    python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel
 
-FROM python-install AS python-dependencies
+FROM python-install AS numa-build
 # Install numactl (needed for numa.h dependency)
 WORKDIR /tmp
 RUN curl -LO https://github.com/numactl/numactl/archive/refs/tags/v2.0.16.tar.gz && \
@@ -89,6 +82,18 @@ RUN curl -LO https://github.com/numactl/numactl/archive/refs/tags/v2.0.16.tar.gz
 # Set include path
 ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH"
 
+FROM python-install AS rust
+ENV CARGO_HOME=/root/.cargo
+ENV RUSTUP_HOME=/root/.rustup
+ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"
+
+RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
+    . "$CARGO_HOME/env" && \
+    rustup default stable && \
+    rustup show
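+
+# (Editorial note, an assumption about the design: keeping the Rust toolchain
+# in its own stage lets later stages bind-mount $CARGO_HOME and $RUSTUP_HOME
+# at build time, so the toolchain never enters the final image layers.)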
"$CARGO_HOME/env" && \ + rustup default stable && \ + rustup show + +FROM python-install AS python-dependencies + # Copy vLLM source COPY . /workspace/vllm WORKDIR /workspace/vllm @@ -98,16 +103,19 @@ ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi -# Copy built PyArrow wheels from pyarrow stage -COPY --from=pyarrow /tmp/arrow_wheels/*.whl /tmp/arrow_wheels/ - -ENV PATH="/root/.cargo/bin:$PATH" +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" # Install dependencies, including PyTorch and Apache Arrow RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + --mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \ + WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ pip install -v \ - 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 /tmp/arrow_wheels/*.whl \ + 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 $WHL_FILE \ --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ -r requirements-cpu.txt @@ -123,12 +131,20 @@ RUN --mount=type=cache,target=/root/.cache/pip \ uv pip install dist/*.whl # Final build stage -FROM python-dependencies AS vllm-cpu +FROM python-install AS vllm-cpu ARG PYTHON_VERSION # Set correct library path for torch and numactl ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python${PYTHON_VERSION}/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH" WORKDIR /workspace/vllm +ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH" +ENV UV_LINK_MODE=copy + +COPY --from=numa-build /usr/local/lib /usr/local/lib +COPY --from=numa-build /usr/local/include /usr/local/include + +COPY --from=python-dependencies /opt/vllm /opt/vllm +COPY --from=python-dependencies /workspace/vllm /workspace/vllm # Build and install vllm RUN --mount=type=cache,target=/root/.cache/pip \ From d8bf9e2808fe3e90f6baf191247b4698f85103c3 Mon Sep 17 00:00:00 2001 From: Dilip Gowda Bhagavan Date: Mon, 3 Mar 2025 15:05:15 +0530 Subject: [PATCH 10/13] refactor: optimize the Docker layers Signed-off-by: Dilip Gowda Bhagavan --- Dockerfile.s390x | 76 ++++++++++++++++++++---------------------------- 1 file changed, 32 insertions(+), 44 deletions(-) diff --git a/Dockerfile.s390x b/Dockerfile.s390x index 1a8d313619c..5df60e9d599 100644 --- a/Dockerfile.s390x +++ b/Dockerfile.s390x @@ -6,9 +6,6 @@ FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base # Install basic dependencies ARG PYTHON_VERSION ENV PYTHON_VERSION=${PYTHON_VERSION} -RUN microdnf -y update && microdnf install -y \ - python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel \ - && microdnf clean all WORKDIR /workspace @@ -17,7 +14,7 @@ ENV LANG=C.UTF-8 \ # Install development utilities RUN microdnf install -y \ - which procps findutils tar vim git gcc g++ make patch make cython zlib-devel \ + which procps findutils tar vim git gcc g++ make patch make zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ openssl-devel openblas openblas-devel wget autoconf automake libtool cmake && \ microdnf clean all @@ -30,12 +27,9 @@ ENV VIRTUAL_ENV=/opt/vllm ENV PATH="$VIRTUAL_ENV/bin:$PATH" ENV PYTHON_VERSION=${PYTHON_VERSION} RUN microdnf install -y \ - 
python${PYTHON_VERSION}-devel && \ + python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel && \ python${PYTHON_VERSION} -m venv $VIRTUAL_ENV && pip install --no-cache -U pip wheel uv && microdnf clean all -# Upgrade pip and install base tools -RUN python -m pip install -U pip wheel uv cmake setuptools - FROM python-install AS pyarrow # Build Apache Arrow @@ -76,8 +70,7 @@ RUN curl -LO https://github.com/numactl/numactl/archive/refs/tags/v2.0.16.tar.gz cd numactl-2.0.16 && \ ./autogen.sh && \ ./configure && \ - make && \ - make install + make # Set include path ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH" @@ -92,34 +85,9 @@ RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \ rustup default stable && \ rustup show -FROM python-install AS python-dependencies - -# Copy vLLM source -COPY . /workspace/vllm -WORKDIR /workspace/vllm - -# Check git repository integrity if enabled -ARG GIT_REPO_CHECK=0 -RUN --mount=type=bind,source=.git,target=.git \ - if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi - -ENV CARGO_HOME=/root/.cargo -ENV RUSTUP_HOME=/root/.rustup -ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" - -# Install dependencies, including PyTorch and Apache Arrow -RUN --mount=type=cache,target=/root/.cache/pip \ - --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ - --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ - --mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \ - WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ - pip install -v \ - 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 $WHL_FILE \ - --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ - -r requirements-cpu.txt - +FROM python-install AS torch-vision # Install torchvision +ARG TORCH_VERSION=2.6.0.dev20250104+cpu ARG TORCH_VISION_VERSION=v0.20.1 WORKDIR /tmp RUN --mount=type=cache,target=/root/.cache/pip \ @@ -127,8 +95,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \ git clone https://github.com/pytorch/vision.git && \ cd vision && \ git checkout $TORCH_VISION_VERSION && \ - python setup.py bdist_wheel && \ - uv pip install dist/*.whl + pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \ + python setup.py bdist_wheel # Final build stage FROM python-install AS vllm-cpu @@ -136,15 +104,35 @@ ARG PYTHON_VERSION # Set correct library path for torch and numactl ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python${PYTHON_VERSION}/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH" -WORKDIR /workspace/vllm ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH" ENV UV_LINK_MODE=copy +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" + +COPY . 
/workspace/vllm +WORKDIR /workspace/vllm -COPY --from=numa-build /usr/local/lib /usr/local/lib -COPY --from=numa-build /usr/local/include /usr/local/include +RUN --mount=type=bind,from=numa-build,src=/tmp/numactl-2.0.16,target=/numactl \ + make -C /numactl install -COPY --from=python-dependencies /opt/vllm /opt/vllm -COPY --from=python-dependencies /workspace/vllm /workspace/vllm +# Install dependencies, including PyTorch and Apache Arrow +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + --mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \ + --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ + sed -i '/^torch/d' requirements-build.txt && \ + ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ + VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ + uv pip install -v \ + $ARROW_WHL_FILE \ + $VISION_WHL_FILE \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + --index-strategy unsafe-best-match \ + -r requirements-build.txt \ + -r requirements-cpu.txt # Build and install vllm RUN --mount=type=cache,target=/root/.cache/pip \ From 03c3c0999550fb5a77ae022e9c3c8197b3cbab40 Mon Sep 17 00:00:00 2001 From: Dilip Gowda Bhagavan Date: Thu, 6 Mar 2025 18:57:13 +0530 Subject: [PATCH 11/13] use only uv cache and remove pip cache dir Signed-off-by: Dilip Gowda Bhagavan --- Dockerfile.s390x | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/Dockerfile.s390x b/Dockerfile.s390x index 5df60e9d599..fb9959d33e1 100644 --- a/Dockerfile.s390x +++ b/Dockerfile.s390x @@ -14,9 +14,9 @@ ENV LANG=C.UTF-8 \ # Install development utilities RUN microdnf install -y \ - which procps findutils tar vim git gcc g++ make patch make zlib-devel \ + which procps findutils tar vim git gcc g++ make patch zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ - openssl-devel openblas openblas-devel wget autoconf automake libtool cmake && \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake && \ microdnf clean all # Python Installation @@ -34,8 +34,7 @@ FROM python-install AS pyarrow # Build Apache Arrow WORKDIR /tmp -RUN --mount=type=cache,target=/root/.cache/pip \ - --mount=type=cache,target=/root/.cache/uv \ +RUN --mount=type=cache,target=/root/.cache/uv \ git clone https://github.com/apache/arrow.git && \ cd arrow/cpp && \ mkdir release && cd release && \ @@ -59,7 +58,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ cd ../../python && \ export PYARROW_PARALLEL=4 && \ export ARROW_BUILD_TYPE=release && \ - python -m pip install -r requirements-build.txt && \ + uv pip install -r requirements-build.txt && \ python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel FROM python-install AS numa-build @@ -87,15 +86,14 @@ RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \ FROM python-install AS torch-vision # Install torchvision -ARG TORCH_VERSION=2.6.0.dev20250104+cpu +ARG TORCH_VERSION=2.7.0.dev20250304 ARG TORCH_VISION_VERSION=v0.20.1 WORKDIR /tmp -RUN --mount=type=cache,target=/root/.cache/pip \ - --mount=type=cache,target=/root/.cache/uv \ +RUN --mount=type=cache,target=/root/.cache/uv \ git clone https://github.com/pytorch/vision.git && \ cd vision && \ git checkout 
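+
+# (Editorial note, an assumption: torchvision is built from source here
+# because no prebuilt s390x wheel matches the nightly torch pinned above;
+# installing that torch first lets the build compile against it.)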
 
 # Final build stage
@@ -115,8 +114,7 @@ RUN --mount=type=bind,from=numa-build,src=/tmp/numactl-2.0.16,target=/numactl \
     make -C /numactl install
 
 # Install dependencies, including PyTorch and Apache Arrow
-RUN --mount=type=cache,target=/root/.cache/pip \
-    --mount=type=cache,target=/root/.cache/uv \
+RUN --mount=type=cache,target=/root/.cache/uv \
     --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \
     --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
     --mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \
@@ -135,8 +132,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
     -r requirements-cpu.txt
 
 # Build and install vllm
-RUN --mount=type=cache,target=/root/.cache/pip \
-    --mount=type=cache,target=/root/.cache/uv \
+RUN --mount=type=cache,target=/root/.cache/uv \
     VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \
     uv pip install "$(echo dist/*.whl)[tensorizer]"
 
@@ -144,7 +140,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 RUN umask 002 && \
     useradd --uid 2000 --gid 0 vllm && \
     mkdir -p /home/vllm && \
-    chmod g+rwx /home/vllm /usr/src /workspace
+    chmod g+rwx /home/vllm
 
 COPY LICENSE /licenses/vllm.md
 COPY examples/*.jinja /app/data/template/
@@ -153,4 +149,4 @@ USER 2000
 WORKDIR /home/vllm
 
 # Set the default entrypoint
-ENTRYPOINT ["/opt/vllm/bin/python", "-m", "vllm.entrypoints.openai.api_server"]
+ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
\ No newline at end of file

From 272cf03b39505e8b9bb752247575fa3e6e0f6a61 Mon Sep 17 00:00:00 2001
From: Dilip Gowda Bhagavan
Date: Thu, 6 Mar 2025 19:57:05 +0530
Subject: [PATCH 12/13] chore: upgrade pytorch

Signed-off-by: Dilip Gowda Bhagavan
---
 requirements-cpu.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements-cpu.txt b/requirements-cpu.txt
index 930610401e7..9491e27d127 100644
--- a/requirements-cpu.txt
+++ b/requirements-cpu.txt
@@ -4,7 +4,7 @@
 # Dependencies for CPUs
 torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin" and platform_machine != "s390x"
 torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin"
-torch==2.6.0.dev20250104+cpu; platform_machine == "s390x"
+torch==2.7.0.dev20250304; platform_machine == "s390x"
 
 # required for the image processor of minicpm-o-2_6, this must be updated alongside torch
 torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x"

From 3ce44c7c637c310b495a19923404256db58bb902 Mon Sep 17 00:00:00 2001
From: Dilip Gowda Bhagavan
Date: Thu, 6 Mar 2025 21:18:10 +0530
Subject: [PATCH 13/13] fix: needs gcc-fortran to install scipy

Signed-off-by: Dilip Gowda Bhagavan
---
 Dockerfile.s390x | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Dockerfile.s390x b/Dockerfile.s390x
index fb9959d33e1..b499d4cb21d 100644
--- a/Dockerfile.s390x
+++ b/Dockerfile.s390x
@@ -14,7 +14,7 @@ ENV LANG=C.UTF-8 \
 
 # Install development utilities
 RUN microdnf install -y \
-    which procps findutils tar vim git gcc g++ make patch zlib-devel \
+    which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \
     libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \
     openssl-devel openblas openblas-devel autoconf automake libtool cmake && \
     microdnf clean all