first add
This commit is contained in:
6
examples/01_basic/conf/config.yaml
Normal file
6
examples/01_basic/conf/config.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
# 基础配置示例
|
||||
learning_rate: 0.001
|
||||
batch_size: 32
|
||||
epochs: 100
|
||||
model_name: "ResNet50"
|
||||
device: "cuda"
|
||||
27
examples/01_basic/train.py
Normal file
27
examples/01_basic/train.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""基础Hydra配置管理示例"""
|
||||
|
||||
import hydra
|
||||
from omegaconf import DictConfig
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
console = Console()
|
||||
|
||||
@hydra.main(version_base=None, config_path="conf", config_name="config")
|
||||
def train(cfg: DictConfig) -> None:
|
||||
console.print("\n[bold green]🚀 基础Hydra配置示例[/bold green]\n")
|
||||
|
||||
table = Table(title="当前配置")
|
||||
table.add_column("参数", style="cyan")
|
||||
table.add_column("值", style="yellow")
|
||||
|
||||
table.add_row("学习率", str(cfg.learning_rate))
|
||||
table.add_row("批次大小", str(cfg.batch_size))
|
||||
table.add_row("训练轮数", str(cfg.epochs))
|
||||
table.add_row("模型名称", cfg.model_name)
|
||||
|
||||
console.print(table)
|
||||
console.print("\n[bold green]✅ 训练完成![/bold green]\n")
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
Reference in New Issue
Block a user