PolyCoderというプログラミング言語の大規模言語モデルを動かして、
コード生成を試してみたので、その手順メモを残しておきます。

Large Models of Source Code | GitHub
https://github.com/VHellendoorn/Code-LMs

PolyCoderが何なのかは、次の記事を見てもらえば、なんとなく分かると思います。

新たなオープンソースのAIコード生成モデル「PolyCoder」–カーネギーメロン大 | ZDNET
https://japan.zdnet.com/article/35184558/

試した環境:

  • Ubuntu Linux 20.04 / WSL2
  • anaconda3-5.3.1

仮想環境の作成とライブラリのインストール

conda環境を作成して、Activateします。

conda create --name transformers python=3.10
conda activate transformers

conda環境に、pytorchをインストールします。

PyTorchは、以下のサイトから環境にあわせたコマンドを選びます。
https://pytorch.org/get-started/locally/
以下のコマンドはLinux/CONDA/CPU(non-GPU)の場合です。

conda install pytorch torchvision torchaudio cpuonly -c pytorch

conda環境に、transformersをインストールします。

conda install transformers

コード生成を動かしてみる

コード生成は、READMEにサンプルコードが記載されています。

ここでは以下のようなコードを書きました。

generate.py

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

from packaging import version
assert version.parse(transformers.__version__) >= version.parse("4.23.0")

tokenizer = AutoTokenizer.from_pretrained("NinedayWang/PolyCoder-2.7B")
model = AutoModelForCausalLM.from_pretrained("NinedayWang/PolyCoder-2.7B")

prompt = '''def binarySearch(arr, left, right, x):
    mid = (left +'''
input_ids = tokenizer.encode(prompt, return_tensors='pt')
result = model.generate(input_ids, max_length=50, num_beams=4, num_return_sequences=4)

for idx, res in enumerate(result):
    print(f"---- candidate #{idx + 1}")
    print(tokenizer.decode(res))

input_idsには、promptをtokenizeしたリスト
resultには、生成したコードのtokenリストが入っています。

実行結果は、次のとおりです。
promptで指定した、以降のコードを生成しています。

$ python generate.py
---- candidate #1
def binarySearch(arr, left, right, x):
    mid = (left + right) // 2
    if arr[mid] == x:
        return mid

---- candidate #2
def binarySearch(arr, left, right, x):
    mid = (left + right) // 2
    if arr[mid] > x:
        return binarySearch(arr
---- candidate #3
def binarySearch(arr, left, right, x):
    mid = (left + right) // 2
    if arr[mid] < x:
        return -1

---- candidate #4
def binarySearch(arr, left, right, x):
    mid = (left + right) / 2
    if arr[mid] > x:
        return binarySearch(arr

生成したコードのスコアを見てみる

以下を参考にして、生成したコードのスコア(Probability)も見てみます。

Announcement Generation: Get probabilities for generated output | discuss.haggingface https://discuss.huggingface.co/t/announcement-generation-get-probabilities-for-generated-output/30075

ここでは以下のようなコードを書きました。

generate_with_score.py

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np

from packaging import version
assert version.parse(transformers.__version__) >= version.parse("4.23.0")

tokenizer = AutoTokenizer.from_pretrained("NinedayWang/PolyCoder-2.7B")
model = AutoModelForCausalLM.from_pretrained("NinedayWang/PolyCoder-2.7B")

prompt = '''def binarySearch(arr, left, right, x):
    mid = (left +'''
input_ids = tokenizer.encode(prompt, return_tensors='pt')
result = model.generate(input_ids, max_length=50, num_beams=4, num_return_sequences=4, return_dict_in_generate=True, output_scores=True)
transition_scores = model.compute_transition_beam_scores(
    result.sequences, result.scores, result.beam_indices
)
input_length = input_ids.shape[1]
generated_tokens = result.sequences[:, input_length:]

for idx in range(len(result)):
    print(f"---- candidate #{idx + 1}")
    for tok, score in zip(generated_tokens[idx], transition_scores[idx]):
        # | token | token string | logits | probability
        print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.4f} | {np.exp(score.numpy()):.2%}")

generated_tokensには、生成コードtokenリスト(元prompt部分を除いたもの)
transition_scoresには、各tokenのスコアが入っています。

実行結果は、次のとおりです。
候補によって揺れている等号・不等号の箇所は低スコアなのがわかります。

$ python generate_with_score.py
---- candidate #1
|  2052 |  right   | -0.0034 | 99.66%
|    11 | )        | -0.0416 | 95.93%
|   322 |  //      | -0.7376 | 47.83%
|   444 |  2       | -0.0043 | 99.57%
|   188 |
        | -0.0281 | 97.23%
|   209 |          | -0.0443 | 95.66%
|   209 |          | -0.0002 | 99.98%
|   209 |          | -0.0000 | 100.00%
|   392 |  if      | -0.6518 | 52.11%
|  7793 |  arr     | -0.6220 | 53.69%
|    61 | [        | -0.0271 | 97.33%
|  8894 | mid      | -0.0313 | 96.92%
|    63 | ]        | -0.0378 | 96.29%
|   489 |  ==      | -1.5390 | 21.46%
|   754 |  x       | -0.0220 | 97.82%
|    28 | :        | -0.0254 | 97.49%
|   188 |
        | -0.0372 | 96.34%
|   209 |          | -0.0007 | 99.93%
|   209 |          | -0.0000 | 100.00%
|   209 |          | -0.0000 | 100.00%
|   209 |          | -0.0001 | 99.99%
|   209 |          | -0.0004 | 99.96%
|   209 |          | -0.0049 | 99.51%
|   209 |          | -0.0012 | 99.88%
|   429 |  return  | -0.1889 | 82.78%
| 12113 |  mid     | -0.1280 | 87.99%
|   188 |
        | -0.1522 | 85.88%
|   209 |          | -0.0509 | 95.04%
|   209 |          | -0.0001 | 99.99%
---- candidate #2
|  2052 |  right   | -0.0034 | 99.66%
|    11 | )        | -0.0416 | 95.93%
|   322 |  //      | -0.7376 | 47.83%
|   444 |  2       | -0.0043 | 99.57%
|   188 |
        | -0.0281 | 97.23%
|   209 |          | -0.0443 | 95.66%
|   209 |          | -0.0002 | 99.98%
|   209 |          | -0.0000 | 100.00%
|   392 |  if      | -0.6518 | 52.11%
|  7793 |  arr     | -0.6220 | 53.69%
|    61 | [        | -0.0271 | 97.33%
|  8894 | mid      | -0.0313 | 96.92%
|    63 | ]        | -0.0378 | 96.29%
|   609 |  >       | -1.0378 | 35.42%
|   754 |  x       | -0.0059 | 99.41%
|    28 | :        | -0.0159 | 98.42%
|   188 |
        | -0.0474 | 95.37%
|   209 |          | -0.0005 | 99.95%
|   209 |          | -0.0000 | 100.00%
|   209 |          | -0.0000 | 100.00%
|   209 |          | -0.0001 | 99.99%
|   209 |          | -0.0006 | 99.94%
|   209 |          | -0.0055 | 99.45%
|   209 |          | -0.0010 | 99.90%
|   429 |  return  | -0.4343 | 64.77%
|  4012 |  binary  | -0.9906 | 37.14%
|  4648 | Search   | -0.0017 | 99.83%
|    10 | (        | -0.0097 | 99.03%
|   821 | arr      | -0.0165 | 98.36%
---- candidate #3
|  2052 |  right   | -0.0034 | 99.66%
|    11 | )        | -0.0416 | 95.93%
|   322 |  //      | -0.7376 | 47.83%
|   444 |  2       | -0.0043 | 99.57%
|   188 |
        | -0.0281 | 97.23%
|   209 |          | -0.0443 | 95.66%
|   209 |          | -0.0002 | 99.98%
|   209 |          | -0.0000 | 100.00%
|   392 |  if      | -0.6518 | 52.11%
|  7793 |  arr     | -0.6220 | 53.69%
|    61 | [        | -0.0271 | 97.33%
|  8894 | mid      | -0.0313 | 96.92%
|    63 | ]        | -0.0378 | 96.29%
|   360 |  <       | -1.2713 | 28.05%
|   754 |  x       | -0.0211 | 97.91%
|    28 | :        | -0.0229 | 97.73%
|   188 |
        | -0.0492 | 95.20%
|   209 |          | -0.0006 | 99.94%
|   209 |          | -0.0000 | 100.00%
|   209 |          | -0.0000 | 100.00%
|   209 |          | -0.0001 | 99.99%
|   209 |          | -0.0004 | 99.96%
|   209 |          | -0.0059 | 99.41%
|   209 |          | -0.0012 | 99.88%
|   429 |  return  | -0.6510 | 52.15%
|   418 |  -       | -0.8807 | 41.45%
|    19 | 1        | -0.0478 | 95.33%
|   188 |
        | -0.0562 | 94.54%
|   209 |          | -0.0077 | 99.23%
---- candidate #4
|  2052 |  right   | -0.0034 | 99.66%
|    11 | )        | -0.0416 | 95.93%
|   348 |  /       | -1.1086 | 33.00%
|   444 |  2       | -0.0054 | 99.46%
|   188 |
        | -0.0724 | 93.02%
|   209 |          | -0.0398 | 96.10%
|   209 |          | -0.0003 | 99.97%
|   209 |          | -0.0000 | 100.00%
|   392 |  if      | -0.7080 | 49.26%
|  7793 |  arr     | -0.7701 | 46.30%
|    61 | [        | -0.0388 | 96.19%
|  8894 | mid      | -0.0600 | 94.18%
|    63 | ]        | -0.0414 | 95.94%
|   609 |  >       | -1.0747 | 34.14%
|   754 |  x       | -0.0057 | 99.43%
|    28 | :        | -0.0158 | 98.44%
|   188 |
        | -0.0439 | 95.71%
|   209 |          | -0.0006 | 99.94%
|   209 |          | -0.0000 | 100.00%
|   209 |          | -0.0000 | 100.00%
|   209 |          | -0.0001 | 99.99%
|   209 |          | -0.0006 | 99.94%
|   209 |          | -0.0047 | 99.53%
|   209 |          | -0.0012 | 99.88%
|   429 |  return  | -0.4082 | 66.48%
|  4012 |  binary  | -0.7974 | 45.05%
|  4648 | Search   | -0.0019 | 99.81%
|    10 | (        | -0.0094 | 99.06%
|   821 | arr      | -0.0189 | 98.13%

以上。