torch.compile()

제민욱 - Feb 25 - - Dev Community

torch.compile이란?

torch >= 2.0이후에 pytorch code의 speed를 up하기 위해서, JIT-compiling을 통해서 pytorch 코드를 optimized된 nvidia kernel로 변경해주는 기법입니다.

  • 실제 컴파일된 Binary를 만드는 것이 아닌, JIT로 runtime에 bytecode가 컴파일되고 캐시되어 (~/.cache/torch) 사용됩니다.

위의 특성에 의해서 compile()코드는 처음 또는 최초의 몇번은 오히려 eager에 비해서 더 느린 속도를 낼 수 있다.

실제 컴파일과 관련된 기능은 torch.export()으로, 이는 AOT compile로 .cubin(GPU), .so(C++)로 컴파일 할 수 있다.

Components(2)

Image description

Image above

내부적으로 torch.compile()은 크게 (2)요소로 구성된다.

  1. Torch Dynamo
  2. Torch Inductor

Good to know

Triton은 pytorch 2.0 binary에 내부적으로 포함되어있습니다. 그러나 혹시 문제가 있을 경우 pip install torchtriton이 필요합니다. 간단하게 찾아본 바에 의하면 Triton은 CUDA와 비교해서 CUDA가 저수준이면 Triton은 고수준이고, Nvidia외에도 AMD등의 GPU 아키텍처에서 실행가능합니다.

torch.complie()은 내부적으로 python의 오버헤드를 줄이며, 또한 cuda-graph를 토대로 효율적으로 CPU-GPU간의 통신 오버헤드를 줄입니다.

Usage

torch.compile은 크게 3곳에 사용가능합니다.

  1. function
  2. nn.module
  3. nested (function and module)
. . . . . .