9 changes: 9 additions & 0 deletions sourced/ml/core/algorithms/__init__.py
@@ -0,0 +1,9 @@
# flake8: noqa
from sourced.ml.core.algorithms.tf_idf import log_tf_log_idf
from sourced.ml.core.algorithms.uast_ids_to_bag import UastIds2Bag, uast2sequence
from sourced.ml.core.algorithms.uast_struct_to_bag import UastRandomWalk2Bag, UastSeq2Bag
from sourced.ml.core.algorithms.uast_inttypes_to_nodes import Uast2QuantizedChildren
from sourced.ml.core.algorithms.uast_inttypes_to_graphlets import Uast2GraphletBag
from sourced.ml.core.algorithms.uast_to_role_id_pairs import Uast2RoleIdPairs
from sourced.ml.core.algorithms.uast_id_distance import Uast2IdLineDistance, Uast2IdTreeDistance
from sourced.ml.core.algorithms.uast_to_id_sequence import Uast2IdSequence
42 changes: 42 additions & 0 deletions sourced/ml/core/algorithms/id_embedding.py
@@ -0,0 +1,42 @@
import numpy


def extract_coocc_matrix(global_shape, word_indices, model):
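    """
    Map a model's local co-occurrence matrix onto the coordinates of the global vocabulary.

    :param global_shape: shape of the resulting matrix, defined by the global vocabulary.
    :param word_indices: mapping from a token to its index in the global vocabulary.
    :param model: co-occurrence model which provides `tokens` and a sparse `matrix`.
    :return: `scipy.sparse.csr_matrix` of shape `global_shape` in global coordinates.
    """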
# Stage 1 - extract the tokens, map them to the global vocabulary
indices = []
mapped_indices = []
for i, w in enumerate(model.tokens):
gi = word_indices.get(w)
if gi is not None:
indices.append(i)
mapped_indices.append(gi)
indices = numpy.array(indices)
mapped_indices = numpy.array(mapped_indices)
# Stage 2 - sort the matched tokens by the index in the vocabulary
order = numpy.argsort(mapped_indices)
indices = indices[order]
mapped_indices = mapped_indices[order]
# Stage 3 - produce the csr_matrix with the matched tokens **only**
matrix = model.matrix.tocsr()[indices][:, indices]
# Stage 4 - convert this matrix to the global (ccmatrix) coordinates
csr_indices = matrix.indices
for i, v in enumerate(csr_indices):
# Here we use the fact that indices and mapped_indices are in the same order
csr_indices[i] = mapped_indices[v]
csr_indptr = matrix.indptr
new_indptr = [0]
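    # Rebuild indptr row by row; global rows with no matched local token stay empty
    # (their pointer simply repeats the previous value).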
for i, v in enumerate(mapped_indices):
prev_ptr = csr_indptr[i]
ptr = csr_indptr[i + 1]

# Handle missing rows
prev = (mapped_indices[i - 1] + 1) if i > 0 else 0
for _ in range(prev, v):
new_indptr.append(prev_ptr)

new_indptr.append(ptr)
for _ in range(mapped_indices[-1] + 1, global_shape[0]):
new_indptr.append(csr_indptr[-1])
matrix.indptr = numpy.array(new_indptr)
matrix._shape = global_shape
return matrix
128 changes: 128 additions & 0 deletions sourced/ml/core/algorithms/id_splitter/README.md
@@ -0,0 +1,128 @@
# Neural Identifier Splitter
Based on the article [Splitting source code identifiers using Bidirectional LSTM Recurrent Neural Network](https://arxiv.org/abs/1805.11651).

### Agenda
* Data
* Training pipeline
* How to launch

### Data
You can download the dataset [here](https://drive.google.com/open?id=1wZR5zF1GL1fVcA1gZuAN_9rSLd5ssqKV). More information about the dataset is available [here](https://github.com/src-d/datasets/tree/master/Identifiers).
#### Data format
* file format: `.csv.gz`.
* the CSV structure (a loading sketch follows the table):

|num_files|num_occ|num_repos|token|token_split|
|:--|:--|:--|:--|:--|
|1|2|1|quesesSet|queses set|
|...|...|...|...|...|
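
For a quick look at the raw rows, here is a minimal sketch that opens the archive the same way `read_identifiers` in `features.py` does (a tar archive containing a single CSV member); the path is a placeholder:

```python
import tarfile

# Print the header and the first three data rows of the downloaded dataset.
with tarfile.open("/path/to/input.csv.gz", encoding="utf-8") as archive:
    content = archive.extractfile(archive.getmembers()[0])
    for _ in range(4):  # header line + three rows
        print(content.readline().decode("utf-8").strip())
```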

#### Data stats
* 49 million identifiers
* 1 GB

### Training pipeline
The training pipeline consists of several steps (see the sketch after this list):
* [prepare features](https://github.com/src-d/ml/blob/master/sourced/ml/algorithms/id_splitter/features.py#L44-#L118) - read data, extract features, train/test split
* [prepare generators for keras](https://github.com/src-d/ml/blob/master/sourced/ml/cmd/train_id_split.py#L34-#L48)
* [prepare model - RNN or CNN](https://github.com/src-d/ml/blob/master/sourced/ml/cmd/train_id_split.py#L53-#L76)
* [training](https://github.com/src-d/ml/blob/master/sourced/ml/cmd/train_id_split.py#L78-#L89)
* [quality report and save the model](https://github.com/src-d/ml/blob/master/sourced/ml/cmd/train_id_split.py#L91-#L96)
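
A minimal sketch of the feature-preparation step, assuming the import path from the file layout added in this change (`sourced.ml.core.algorithms.id_splitter.features`); the argument values mirror the CLI defaults listed below:

```python
from sourced.ml.core.algorithms.id_splitter.features import prepare_features

# Read the archive, drop identifiers longer than 40 characters, convert the rest
# to padded character-index sequences and split them into train/test parts.
X_train, X_test, y_train, y_test = prepare_features(
    csv_path="/path/to/input.csv.gz",
    use_header=False,          # skip the first CSV line (the header)
    max_identifier_len=40,     # matches the CLI --length default
    identifier_col=3,          # raw identifier column (--csv-identifier default)
    split_identifier_col=4,    # split identifier column (--csv-identifier-split default)
    test_ratio=0.2,            # fraction of the data held out for evaluation
    padding="post",            # pad after each sequence
)
```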

### How to launch
First, download the dataset using the link above.

Usage:
```console
usage: srcml train-id-split [-h] -i INPUT [-e EPOCHS] [-b BATCH_SIZE]
[-l LENGTH] -o OUTPUT [-t TEST_RATIO]
[-p {pre,post}] [--optimizer {RMSprop,Adam}]
[--lr LR] [--final-lr FINAL_LR]
[--samples-before-report SAMPLES_BEFORE_REPORT]
[--val-batch-size VAL_BATCH_SIZE] [--seed SEED]
[--devices DEVICES]
[--csv-identifier CSV_IDENTIFIER]
[--csv-identifier-split CSV_IDENTIFIER_SPLIT]
[--include-csv-header] --model {RNN,CNN}
[-s STACK]
[--type-cell {GRU,LSTM,CuDNNLSTM,CuDNNGRU}]
[-n NEURONS] [-f FILTERS] [-k KERNEL_SIZES]
[--dim-reduction DIM_REDUCTION]

optional arguments:
-h, --help show this help message and exit
-i INPUT, --input INPUT
Path to the input data in CSV
format:num_files,num_occ,num_repos,token,token_split
-e EPOCHS, --epochs EPOCHS
                        Number of training epochs. The more the better but the
training time is proportional. (default: 10)
-b BATCH_SIZE, --batch-size BATCH_SIZE
                        Batch size. Higher values better utilize GPUs but may
harm the convergence. (default: 500)
-l LENGTH, --length LENGTH
RNN sequence length. (default: 40)
-o OUTPUT, --output OUTPUT
Path to store the trained model.
-t TEST_RATIO, --test-ratio TEST_RATIO
Fraction of the dataset to use for evaluation.
(default: 0.2)
-p {pre,post}, --padding {pre,post}
Whether to pad before or after each sequence.
(default: post)
--optimizer {RMSprop,Adam}
Algorithm to use as an optimizer for the neural net.
(default: Adam)
--lr LR Initial learning rate. (default: 0.001)
--final-lr FINAL_LR Final learning rate. The decrease from the initial
learning rate is done linearly. (default: 1e-05)
--samples-before-report SAMPLES_BEFORE_REPORT
                        Number of samples between each validation report and
training updates. (default: 5000000)
--val-batch-size VAL_BATCH_SIZE
                        Batch size for validation. It can be increased to speed
                        up the pipeline but it proportionally increases the
memory consumption. (default: 2000)
--seed SEED Random seed. (default: 1989)
--devices DEVICES Device(s) to use. '-1' means CPU. (default: 0)
--csv-identifier CSV_IDENTIFIER
Column name in the CSV file for the raw identifier.
(default: 3)
--csv-identifier-split CSV_IDENTIFIER_SPLIT
                        Column name in the CSV file for the split identifier.
(default: 4)
--include-csv-header Treat the first line of the input CSV as a
                        regular line. (default: False)
--model {RNN,CNN} Neural Network model to use to learn the
                        identifier splitting task.
-s STACK, --stack STACK
Number of layers stacked on each other. (default: 2)
--type-cell {GRU,LSTM,CuDNNLSTM,CuDNNGRU}
Recurrent layer type to use. (default: LSTM)
-n NEURONS, --neurons NEURONS
Number of neurons on each layer. (default: 256)
-f FILTERS, --filters FILTERS
Number of filters for each kernel size. (default:
64,32,16,8)
-k KERNEL_SIZES, --kernel-sizes KERNEL_SIZES
Sizes for sliding windows. (default: 2,4,8,16)
--dim-reduction DIM_REDUCTION
                        Number of 1-d kernels to reduce dimensionality after
each layer. (default: 32)
```


Examples of commands:
1) Train RNN with LSTM cells
```console
srcml train-id-split --model RNN --input /path/to/input.csv.gz --output /path/to/output
```
2) Train RNN with CuDNNLSTM cells
```console
srcml train-id-split --model RNN --input /path/to/input.csv.gz --output /path/to/output \
--type-cell CuDNNLSTM
```
3) Train CNN
```console
srcml train-id-split --model CNN --input /path/to/input.csv.gz --output /path/to/output
```
118 changes: 118 additions & 0 deletions sourced/ml/core/algorithms/id_splitter/features.py
@@ -0,0 +1,118 @@
import logging
import string
import tarfile
from typing import List, Tuple

from modelforge.progress_bar import progress_bar
import numpy


def read_identifiers(csv_path: str, use_header: bool, max_identifier_len: int, identifier_col: int,
split_identifier_col: int, shuffle: bool = True) -> List[str]:
"""
Reads and filters too long identifiers in the CSV file.

:param csv_path: path to the CSV file.
    :param use_header: treat the first line as a regular data row (True) or skip it
                       as a header with column names (False).
:param max_identifier_len: maximum length of raw identifiers. Skip identifiers that are longer.
    :param identifier_col: column index in the CSV file of the raw identifier.
    :param split_identifier_col: column index in the CSV file of the split, lowercase identifier.
:param shuffle: indicates whether to reorder the list of identifiers
at random after reading it.
:return: list of split identifiers.
"""
log = logging.getLogger("read_identifiers")
log.info("Reading data from the CSV file %s", csv_path)
identifiers = []
# TODO: Update dataset loading as soon as https://github.com/src-d/backlog/issues/1212 done
# Think about dataset download step
with tarfile.open(csv_path, encoding="utf-8") as f:
assert len(f.members) == 1, "One archived file is expected, got: %s" % len(f.members)
content = f.extractfile(f.members[0])
if not use_header:
content.readline()
for line in progress_bar(content.readlines(), log):
row = line.decode("utf-8").strip().split(",")
if len(row[identifier_col]) <= max_identifier_len:
identifiers.append(row[split_identifier_col])
if shuffle:
numpy.random.shuffle(identifiers)
log.info("Number of identifiers after filtering: %s." % len(identifiers))
return identifiers


def prepare_features(csv_path: str, use_header: bool, max_identifier_len: int,
identifier_col: int, split_identifier_col: int, test_ratio: float,
padding: str, shuffle: bool = True) -> Tuple[numpy.array]:
"""
Prepare the features to train the identifier splitting task.

:param csv_path: path to the CSV file.
    :param use_header: treat the first line as a regular data row (True) or skip it
                       as a header with column names (False).
:param max_identifier_len: maximum length of raw identifiers. Skip identifiers that are longer.
    :param identifier_col: column index in the CSV file of the raw identifier.
    :param split_identifier_col: column index in the CSV file of the split identifier.
:param shuffle: indicates whether to reorder the list of identifiers
at random after reading it.
:param test_ratio: Proportion of test samples used for evaluation.
:param padding: position where to add padding values:
        after the input sequence if "post", before if "pre".
:return: training and testing features to train the neural net for the splitting task.
"""
from keras.preprocessing.sequence import pad_sequences
log = logging.getLogger("prepare_features")

# read data from the input file
identifiers = read_identifiers(csv_path=csv_path, use_header=use_header,
max_identifier_len=max_identifier_len,
identifier_col=identifier_col,
split_identifier_col=split_identifier_col, shuffle=shuffle)

log.info("Converting identifiers to character indices")
log.info("Number of identifiers: %d, Average length: %d characters" %
(len(identifiers), numpy.mean([len(i) for i in identifiers])))

char2ind = {c: i + 1 for i, c in enumerate(sorted(string.ascii_lowercase))}

char_id_seq = []
splits = []
for identifier in identifiers:
# iterate through the identifier and convert to array of char indices & boolean split array
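        # e.g. "foo bar" -> index_arr covers the characters of "foobar" and
        # split_arr == [0, 0, 0, 1, 0, 0], where 1 marks the first character of a new subtoken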
index_arr = []
split_arr = []
skip_char = False
for char in identifier.strip():
if char in char2ind:
index_arr.append(char2ind[char])
if skip_char:
skip_char = False
continue
split_arr.append(0)
elif char == " ":
split_arr.append(1)
skip_char = True
else:
log.warning("Unexpected symbol %s in identifier", char)
assert len(index_arr) == len(split_arr)
char_id_seq.append(index_arr)
splits.append(split_arr)

log.info("Number of subtokens: %d, Number of distinct characters: %d" %
(sum(sum(split_arr) for split_arr in splits) + len(identifiers),
len({i for index_arr in char_id_seq for i in index_arr})))

log.info("Train/test splitting...")
n_train = int((1 - test_ratio) * len(char_id_seq))
X_train = char_id_seq[:n_train]
X_test = char_id_seq[n_train:]
y_train = splits[:n_train]
y_test = splits[n_train:]
log.info("Number of train samples: %s, number of test samples: %s" % (len(X_train),
len(X_test)))
log.info("Padding the sequences...")
X_train = pad_sequences(X_train, maxlen=max_identifier_len, padding=padding)
X_test = pad_sequences(X_test, maxlen=max_identifier_len, padding=padding)
y_train = pad_sequences(y_train, maxlen=max_identifier_len, padding=padding)
y_test = pad_sequences(y_test, maxlen=max_identifier_len, padding=padding)
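    # Keras expects labels of shape (batch, length, 1), hence the trailing axis added below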

return X_train, X_test, y_train[:, :, None], y_test[:, :, None]