Фундамент: Выбор правильного дистрибутива и управление зависимостями. Первая ошибка — установка через `pip install jax` без указания версий. Это приведет к потенциально несовместимому набору зависимостей (jaxlib, numpy, и т.д.). Профессионалы используют строгое управление окружением.
Шаг 1: Создайте изолированное окружение. Используйте `conda` (предпочтительно для управления нативными библиотеками, особенно под Linux) или `uv`/`poetry`/`venv` с фиксацией версий. Для conda: `conda create -n jax-env python=3.10`. Активируйте его.
Шаг 2: Установка под CPU (базовый вариант). Это самый простой путь, но он ограничивает возможности. Используйте официальные wheels от Google. Для Linux/Mac: `pip install --upgrade "jax[cpu]"`. Для Windows (через WSL2 или напрямую, с ограничениями): также `jax[cpu]`. Эта команда установит совместимые версии `jax` и `jaxlib`. Убедитесь, что у вас установлен современный компилятор (например, gcc).
Шаг 3: Установка под NVIDIA GPU — основной сценарий. Здесь есть два пути: использование предсобранных wheels или компиляция из исходников для максимального контроля.
Вариант А (Рекомендуемый для большинства): Установка с поддержкой CUDA и cuDNN. Зайдите на официальную страницу совместимости JAX (GitHub). Найдите таблицу, связывающую версию `jaxlib`, версию CUDA и версию cuDNN. Например, `jaxlib==0.4.20+cuda11.cudnn86`. Установите драйверы NVIDIA, CUDA Toolkit и cuDNN в систему, следуя инструкциям NVIDIA. Версия драйвера должна поддерживать выбранную версию CUDA. Затем установите JAX командой вида:
`pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`
Обратите внимание на суффикс (`cuda11`). Это установит совместимый `jaxlib`. Проверьте установку: `python -c "import jax; print(jax.devices())"`. Должен появиться список ваших GPU.
Вариант Б (Для энтузиастов и специфичных окружений): Компиляция из исходников. Клонируйте репозитории `jax` и `jaxlib`. Установите Bazel (систему сборки). Отредактируйте конфигурационные файлы `./build/` в `jaxlib`, чтобы точно указать пути к CUDA, cuDNN, версию compute capability ваших GPU (например, `sm_86` для RTX 30xx). Сборка займет значительное время, но даст вам бинарник, оптимизированный именно под вашу архитектуру. Это оправдано для продакшн-кластера.
Шаг 4: Установка под AMD GPU (ROCm). Поддержка стабильно улучшается. Требуется система с установленным ROCm stack (драйверы, HIP, ROCm). Установка производится через `pip` с указанием конкретного репозитория:
`pip install --upgrade "jax[rocm]" -f https://storage.googleapis.com/jax-releases/jax_rocm_releases.html`
Проверьте поддержку вашей конкретной карты (например, MI series поддерживаются хорошо). Может потребоваться настройка переменных окружения, таких как `HSA_OVERRIDE_GFX_VERSION`.
Шаг 5: Установка для работы с Google Cloud TPU. Это отдельный мир. Самый простой способ — использовать предварительно настроенные образы в Google Cloud (Deep Learning VM) или Google Colab. Для локальной разработки с эмуляцией TPU можно установить `jax[tpu]` через pip, но для реального использования требуется доступ к облачным TPU и установка `libtpu`. Процесс включает аутентификацию в GCP, создание TPU-виртуальной машины и установку JAX с указанием пути к `libtpu`.
Продвинутая настройка и оптимизация:
- **Воспроизводимость и детерминизм:** Установите `jax.config.update('jax_default_matmul_precision', 'tensorfloat32')` для баланса скорости и точности на Ampere+ GPU. Для полного детерминизма (ценой производительности) используйте `jax.config.update('jax_disable_jit', False)` (JIT остается, но семантика меняется) и установите seed для генераторов случайных чисел.
- **Оптимизация использования памяти:** Изучите и используйте `jax.jit` с аргументами `donate_argnums` для переиспользования памяти буферов. Включите поддержку `memory profiler` для отслеживания утечек. Для больших моделей используйте стратегию шардинга данных (`jax.pmap` или новую `jax.experimental.maps`/`shard_map`).
- **Интеграция с экосистемой:** Установите совместимые библиотеки. `Flax` или `Haiku` для нейросетевых слоев, `Optax` для оптимизаторов, `Chex` для тестирования. Убедитесь, что их версии совместимы с вашей версией JAX. Используйте `pip` с constraints файлом или `conda` для согласования.
- **Профилирование и отладка:** Установите `perfetto` или используйте встроенный профилировщик JAX (`jax.profiler`). Настройте логирование для отслеживания компиляции JIT (это может занимать время при первом запуске).
Комментарии (12)