PolyCoderでコード生成を試してみたメモ
PolyCoderというプログラミング言語の大規模言語モデルを動かして、
コード生成を試してみたので、その手順メモを残しておきます。
Large Models of Source Code | GitHub
https://github.com/VHellendoorn/Code-LMs
PolyCoderが何なのかは、次の記事を見てもらえば、なんとなく分かると思います。
新たなオープンソースのAIコード生成モデル「PolyCoder」--カーネギーメロン大 | ZDNET
https://japan.zdnet.com/article/35184558/
試した環境:
- Ubuntu Linux 20.04 / WSL2
- anaconda3-5.3.1
仮想環境の作成とライブラリのインストール
conda環境を作成して、Activateします。
conda create --name transformers python=3.10
conda activate transformers
conda環境に、pytorchをインストールします。
PyTorchは、以下のサイトから環境にあわせたコマンドを選びます。
https://pytorch.org/get-started/locally/
以下のコマンドはLinux/CONDA/CPU(non-GPU)の場合です。
conda install pytorch torchvision torchaudio cpuonly -c pytorch
conda環境に、transformersをインストールします。
conda install transformers
コード生成を動かしてみる
コード生成は、READMEにサンプルコードが記載されています。
ここでは以下のようなコードを書きました。
generate.py
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from packaging import version
assert version.parse(transformers.__version__) >= version.parse("4.23.0")
tokenizer = AutoTokenizer.from_pretrained("NinedayWang/PolyCoder-2.7B")
model = AutoModelForCausalLM.from_pretrained("NinedayWang/PolyCoder-2.7B")
prompt = '''def binarySearch(arr, left, right, x):
mid = (left +'''
input_ids = tokenizer.encode(prompt, return_tensors='pt')
result = model.generate(input_ids, max_length=50, num_beams=4, num_return_sequences=4)
for idx, res in enumerate(result):
print(f"---- candidate #{idx + 1}")
print(tokenizer.decode(res))
input_idsには、promptをtokenizeしたリスト
resultには、生成したコードのtokenリストが入っています。
実行結果は、次のとおりです。
promptで指定した、以降のコードを生成しています。
$ python generate.py
---- candidate #1
def binarySearch(arr, left, right, x):
mid = (left + right) // 2
if arr[mid] == x:
return mid
---- candidate #2
def binarySearch(arr, left, right, x):
mid = (left + right) // 2
if arr[mid] > x:
return binarySearch(arr
---- candidate #3
def binarySearch(arr, left, right, x):
mid = (left + right) // 2
if arr[mid] < x:
return -1
---- candidate #4
def binarySearch(arr, left, right, x):
mid = (left + right) / 2
if arr[mid] > x:
return binarySearch(arr
生成したコードのスコアを見てみる
以下を参考にして、生成したコードのスコア(Probability)も見てみます。
Announcement Generation: Get probabilities for generated output | discuss.haggingface https://discuss.huggingface.co/t/announcement-generation-get-probabilities-for-generated-output/30075
ここでは以下のようなコードを書きました。
generate_with_score.py
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from packaging import version
assert version.parse(transformers.__version__) >= version.parse("4.23.0")
tokenizer = AutoTokenizer.from_pretrained("NinedayWang/PolyCoder-2.7B")
model = AutoModelForCausalLM.from_pretrained("NinedayWang/PolyCoder-2.7B")
prompt = '''def binarySearch(arr, left, right, x):
mid = (left +'''
input_ids = tokenizer.encode(prompt, return_tensors='pt')
result = model.generate(input_ids, max_length=50, num_beams=4, num_return_sequences=4, return_dict_in_generate=True, output_scores=True)
transition_scores = model.compute_transition_beam_scores(
result.sequences, result.scores, result.beam_indices
)
input_length = input_ids.shape[1]
generated_tokens = result.sequences[:, input_length:]
for idx in range(len(result)):
print(f"---- candidate #{idx + 1}")
for tok, score in zip(generated_tokens[idx], transition_scores[idx]):
# | token | token string | logits | probability
print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.4f} | {np.exp(score.numpy()):.2%}")
generated_tokensには、生成コードtokenリスト(元prompt部分を除いたもの)
transition_scoresには、各tokenのスコアが入っています。
実行結果は、次のとおりです。
候補によって揺れている等号・不等号の箇所は低スコアなのがわかります。
$ python generate_with_score.py
---- candidate #1
| 2052 | right | -0.0034 | 99.66%
| 11 | ) | -0.0416 | 95.93%
| 322 | // | -0.7376 | 47.83%
| 444 | 2 | -0.0043 | 99.57%
| 188 |
| -0.0281 | 97.23%
| 209 | | -0.0443 | 95.66%
| 209 | | -0.0002 | 99.98%
| 209 | | -0.0000 | 100.00%
| 392 | if | -0.6518 | 52.11%
| 7793 | arr | -0.6220 | 53.69%
| 61 | [ | -0.0271 | 97.33%
| 8894 | mid | -0.0313 | 96.92%
| 63 | ] | -0.0378 | 96.29%
| 489 | == | -1.5390 | 21.46%
| 754 | x | -0.0220 | 97.82%
| 28 | : | -0.0254 | 97.49%
| 188 |
| -0.0372 | 96.34%
| 209 | | -0.0007 | 99.93%
| 209 | | -0.0000 | 100.00%
| 209 | | -0.0000 | 100.00%
| 209 | | -0.0001 | 99.99%
| 209 | | -0.0004 | 99.96%
| 209 | | -0.0049 | 99.51%
| 209 | | -0.0012 | 99.88%
| 429 | return | -0.1889 | 82.78%
| 12113 | mid | -0.1280 | 87.99%
| 188 |
| -0.1522 | 85.88%
| 209 | | -0.0509 | 95.04%
| 209 | | -0.0001 | 99.99%
---- candidate #2
| 2052 | right | -0.0034 | 99.66%
| 11 | ) | -0.0416 | 95.93%
| 322 | // | -0.7376 | 47.83%
| 444 | 2 | -0.0043 | 99.57%
| 188 |
| -0.0281 | 97.23%
| 209 | | -0.0443 | 95.66%
| 209 | | -0.0002 | 99.98%
| 209 | | -0.0000 | 100.00%
| 392 | if | -0.6518 | 52.11%
| 7793 | arr | -0.6220 | 53.69%
| 61 | [ | -0.0271 | 97.33%
| 8894 | mid | -0.0313 | 96.92%
| 63 | ] | -0.0378 | 96.29%
| 609 | > | -1.0378 | 35.42%
| 754 | x | -0.0059 | 99.41%
| 28 | : | -0.0159 | 98.42%
| 188 |
| -0.0474 | 95.37%
| 209 | | -0.0005 | 99.95%
| 209 | | -0.0000 | 100.00%
| 209 | | -0.0000 | 100.00%
| 209 | | -0.0001 | 99.99%
| 209 | | -0.0006 | 99.94%
| 209 | | -0.0055 | 99.45%
| 209 | | -0.0010 | 99.90%
| 429 | return | -0.4343 | 64.77%
| 4012 | binary | -0.9906 | 37.14%
| 4648 | Search | -0.0017 | 99.83%
| 10 | ( | -0.0097 | 99.03%
| 821 | arr | -0.0165 | 98.36%
---- candidate #3
| 2052 | right | -0.0034 | 99.66%
| 11 | ) | -0.0416 | 95.93%
| 322 | // | -0.7376 | 47.83%
| 444 | 2 | -0.0043 | 99.57%
| 188 |
| -0.0281 | 97.23%
| 209 | | -0.0443 | 95.66%
| 209 | | -0.0002 | 99.98%
| 209 | | -0.0000 | 100.00%
| 392 | if | -0.6518 | 52.11%
| 7793 | arr | -0.6220 | 53.69%
| 61 | [ | -0.0271 | 97.33%
| 8894 | mid | -0.0313 | 96.92%
| 63 | ] | -0.0378 | 96.29%
| 360 | < | -1.2713 | 28.05%
| 754 | x | -0.0211 | 97.91%
| 28 | : | -0.0229 | 97.73%
| 188 |
| -0.0492 | 95.20%
| 209 | | -0.0006 | 99.94%
| 209 | | -0.0000 | 100.00%
| 209 | | -0.0000 | 100.00%
| 209 | | -0.0001 | 99.99%
| 209 | | -0.0004 | 99.96%
| 209 | | -0.0059 | 99.41%
| 209 | | -0.0012 | 99.88%
| 429 | return | -0.6510 | 52.15%
| 418 | - | -0.8807 | 41.45%
| 19 | 1 | -0.0478 | 95.33%
| 188 |
| -0.0562 | 94.54%
| 209 | | -0.0077 | 99.23%
---- candidate #4
| 2052 | right | -0.0034 | 99.66%
| 11 | ) | -0.0416 | 95.93%
| 348 | / | -1.1086 | 33.00%
| 444 | 2 | -0.0054 | 99.46%
| 188 |
| -0.0724 | 93.02%
| 209 | | -0.0398 | 96.10%
| 209 | | -0.0003 | 99.97%
| 209 | | -0.0000 | 100.00%
| 392 | if | -0.7080 | 49.26%
| 7793 | arr | -0.7701 | 46.30%
| 61 | [ | -0.0388 | 96.19%
| 8894 | mid | -0.0600 | 94.18%
| 63 | ] | -0.0414 | 95.94%
| 609 | > | -1.0747 | 34.14%
| 754 | x | -0.0057 | 99.43%
| 28 | : | -0.0158 | 98.44%
| 188 |
| -0.0439 | 95.71%
| 209 | | -0.0006 | 99.94%
| 209 | | -0.0000 | 100.00%
| 209 | | -0.0000 | 100.00%
| 209 | | -0.0001 | 99.99%
| 209 | | -0.0006 | 99.94%
| 209 | | -0.0047 | 99.53%
| 209 | | -0.0012 | 99.88%
| 429 | return | -0.4082 | 66.48%
| 4012 | binary | -0.7974 | 45.05%
| 4648 | Search | -0.0019 | 99.81%
| 10 | ( | -0.0094 | 99.06%
| 821 | arr | -0.0189 | 98.13%
以上。