# 使用支持CUDA的PyTorch基础镜像 FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime # 设置工作目录 WORKDIR /app # 安装系统依赖 RUN apt-get update && apt-get install -y \ build-essential \ git \ && rm -rf /var/lib/apt/lists/* # 复制requirements.pytorch.txt COPY requirements.pytorch.txt . # 安装Python依赖 RUN pip install --no-cache-dir -r requirements.pytorch.txt # 安装额外的机器学习包 RUN pip install --no-cache-dir \ scikit-learn \ pandas \ matplotlib \ tensorboard # 复制项目文件 COPY . . # 设置环境变量 ENV PYTHONUNBUFFERED=1 # 设置容器启动命令 CMD ["python", "train.py"]