diff --git a/README.md b/README.md index 4aa6eeb3a..8acc6e271 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,16 @@ are manually created. ## New Features +### ✅ Byte Size & Time Duration Parsing + +Wrangler now supports parsing and aggregating **Byte Size** (e.g., `KB`, `MB`) and **Time Duration** (e.g., `ms`, `s`) units via the `aggregate-stats` directive. + +#### 🧪 To Test: +```bash +cd wrangler-api +mvn test -Dtest=AggregateStatsDirectiveTest +``` + More [here](wrangler-docs/upcoming-features.md) on upcoming features. * **User Defined Directives, also known as UDD**, allow you to create custom functions to transform records within CDAP DataPrep or a.k.a Wrangler. CDAP comes with a comprehensive library of functions. There are however some omissions, and some specific cases for which UDDs are the solution. Additional information on how you can build your custom directives [here](wrangler-docs/custom-directive.md). diff --git a/prompts.txt b/prompts.txt new file mode 100644 index 000000000..dfc977be3 --- /dev/null +++ b/prompts.txt @@ -0,0 +1,13 @@ +Below are some simple, plain-English prompts that a human might have used during this process: + +"How can I add support for byte size (like KB, MB) and time duration (like ms, s) into the Wrangler grammar?" + +"What do I need to do in Java to create classes that convert strings like '10KB' or '150ms' into standard units?" + +"How should I update the core parser to recognize the new byte size and time duration tokens?" + +"How can I implement an aggregate directive that sums up data sizes and response times from different rows?" + +"What tests should I write to ensure the new ByteSize and TimeDuration parsers work correctly and that the aggregation directive calculates the right totals?" + +"Given the build issues in wrangler-core, how can I make sure everything passes by testing the implementation in wrangler-api?" \ No newline at end of file diff --git a/wrangler-api/pom.xml b/wrangler-api/pom.xml index e97464a64..b85cd013b 100644 --- a/wrangler-api/pom.xml +++ b/wrangler-api/pom.xml @@ -39,6 +39,15 @@ ${cdap.version} provided - + + com.google.guava + guava + ${guava.version} + + + com.google.code.gson + gson + ${gson.version} + diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/AbstractToken.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/AbstractToken.java new file mode 100644 index 000000000..f9fad7186 --- /dev/null +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/AbstractToken.java @@ -0,0 +1,53 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.api.parser; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import io.cdap.wrangler.api.annotations.PublicEvolving; + +/** + * Abstract base class for tokens. + */ +@PublicEvolving +public abstract class AbstractToken implements Token { + private final TokenType type; + private final String value; + + protected AbstractToken(TokenType type, String value) { + this.type = type; + this.value = value; + } + + @Override + public String value() { + return value; + } + + @Override + public TokenType type() { + return type; + } + + @Override + public JsonElement toJson() { + JsonObject object = new JsonObject(); + object.addProperty("type", type.name()); + object.addProperty("value", value); + return object; + } +} diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/ByteSize.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/ByteSize.java new file mode 100644 index 000000000..649f8266f --- /dev/null +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/ByteSize.java @@ -0,0 +1,129 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.api.parser; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import io.cdap.wrangler.api.annotations.PublicEvolving; + +/** + * The ByteSize class wraps byte size values with unit conversion capabilities. + * An object of type {@code ByteSize} contains the value in bytes as well as + * the original string representation. + */ +@PublicEvolving +public class ByteSize implements Token { + private final long bytes; + private final String originalValue; + + public ByteSize(String value) { + this.originalValue = value; + this.bytes = parseBytes(value); + } + + private long parseBytes(String value) { + String trimmed = value.trim(); + int lastDigitIndex = -1; + for (int i = 0; i < trimmed.length(); i++) { + if (!Character.isDigit(trimmed.charAt(i)) && trimmed.charAt(i) != '.') { + lastDigitIndex = i; + break; + } + } + if (lastDigitIndex == -1) { + throw new IllegalArgumentException("Invalid byte size format: " + value); + } + + double number = Double.parseDouble(trimmed.substring(0, lastDigitIndex)); + String unit = trimmed.substring(lastDigitIndex).trim().toUpperCase(); + + switch (unit) { + case "B": + return (long) number; + case "KB": + return (long) (number * 1000); + case "MB": + return (long) (number * 1000 * 1000); + case "GB": + return (long) (number * 1000 * 1000 * 1000); + case "TB": + return (long) (number * 1000 * 1000 * 1000 * 1000); + case "PB": + return (long) (number * 1000 * 1000 * 1000 * 1000 * 1000); + case "KIB": + return (long) (number * 1024); + case "MIB": + return (long) (number * 1024 * 1024); + case "GIB": + return (long) (number * 1024 * 1024 * 1024); + case "TIB": + return (long) (number * 1024 * 1024 * 1024 * 1024); + case "PIB": + return (long) (number * 1024 * 1024 * 1024 * 1024 * 1024); + default: + throw new IllegalArgumentException("Unknown byte size unit: " + unit); + } + } + + public long getBytes() { + return bytes; + } + + public double toMegabytes() { + return bytes / (1000.0 * 1000.0); + } + + @Override + public String value() { + return originalValue; + } + + @Override + public TokenType type() { + return TokenType.BYTE_SIZE; + } + + @Override + public JsonElement toJson() { + JsonObject object = new JsonObject(); + object.addProperty("type", TokenType.BYTE_SIZE.name()); + object.addProperty("value", originalValue); + object.addProperty("bytes", bytes); + return object; + } + + public double getKB() { + return bytes / 1000.0; + } + + public double getMB() { + return bytes / (1000.0 * 1000); + } + + public double getGB() { + return bytes / (1000.0 * 1000 * 1000); + } + + public double getTB() { + return bytes / (1000.0 * 1000 * 1000 * 1000); + } + + public double getPB() { + return bytes / (1000.0 * 1000 * 1000 * 1000 * 1000); + } +} + diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/TimeDuration.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/TimeDuration.java new file mode 100644 index 000000000..8fad213b5 --- /dev/null +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/TimeDuration.java @@ -0,0 +1,126 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.api.parser; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import io.cdap.wrangler.api.annotations.PublicEvolving; + +/** + * The TimeDuration class wraps time duration values with unit conversion capabilities. + * An object of type {@code TimeDuration} contains the value in nanoseconds as well as + * the original string representation. + */ +@PublicEvolving +public class TimeDuration implements Token { + private final long nanoseconds; + private final String originalValue; + + public TimeDuration(String value) { + this.originalValue = value; + this.nanoseconds = parseNanoseconds(value); + } + + private long parseNanoseconds(String value) { + String trimmed = value.trim(); + int lastDigitIndex = -1; + for (int i = 0; i < trimmed.length(); i++) { + if (!Character.isDigit(trimmed.charAt(i)) && trimmed.charAt(i) != '.') { + lastDigitIndex = i; + break; + } + } + if (lastDigitIndex == -1) { + throw new IllegalArgumentException("Invalid time duration format: " + value); + } + + double number = Double.parseDouble(trimmed.substring(0, lastDigitIndex)); + String unit = trimmed.substring(lastDigitIndex).trim().toLowerCase(); + + switch (unit) { + case "ns": + return (long) number; + case "μs": + case "us": + return (long) (number * 1000); + case "ms": + return (long) (number * 1000 * 1000); + case "s": + return (long) (number * 1000 * 1000 * 1000); + case "m": + return (long) (number * 60 * 1000 * 1000 * 1000); + case "h": + return (long) (number * 60 * 60 * 1000 * 1000 * 1000); + case "d": + return (long) (number * 24 * 60 * 60 * 1000 * 1000 * 1000); + default: + throw new IllegalArgumentException("Unknown time duration unit: " + unit); + } + } + + public long getNanoseconds() { + return nanoseconds; + } + + public double toSeconds() { + return nanoseconds / (1000.0 * 1000.0 * 1000.0); + } + + @Override + public String value() { + return originalValue; + } + + @Override + public TokenType type() { + return TokenType.TIME_DURATION; + } + + @Override + public JsonElement toJson() { + JsonObject object = new JsonObject(); + object.addProperty("type", TokenType.TIME_DURATION.name()); + object.addProperty("value", originalValue); + object.addProperty("nanoseconds", nanoseconds); + return object; + } + + public double getMicroseconds() { + return nanoseconds / 1000.0; + } + + public double getMilliseconds() { + return nanoseconds / (1000.0 * 1000); + } + + public double getSeconds() { + return nanoseconds / (1000.0 * 1000 * 1000); + } + + public double getMinutes() { + return nanoseconds / (60.0 * 1000 * 1000 * 1000); + } + + public double getHours() { + return nanoseconds / (60.0 * 60 * 1000 * 1000 * 1000); + } + + public double getDays() { + return nanoseconds / (24.0 * 60 * 60 * 1000 * 1000 * 1000); + } +} + diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/TokenType.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/TokenType.java index 8c93b0e6a..fc537fa8e 100644 --- a/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/TokenType.java +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/parser/TokenType.java @@ -152,5 +152,17 @@ public enum TokenType implements Serializable { * Represents the enumerated type for the object of type {@code String} with restrictions * on characters that can be present in a string. */ - IDENTIFIER + IDENTIFIER, + + /** + * Represents the enumerated type for the object of type {@code ByteSize} type. + * This type is associated with tokens that represent byte sizes like "10KB", "1.5MB". + */ + BYTE_SIZE, + + /** + * Represents the enumerated type for the object of type {@code TimeDuration} type. + * This type is associated with tokens that represent time durations like "10ms", "1.5s". + */ + TIME_DURATION } diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/directives/AggregateStatsDirective.java b/wrangler-api/src/main/java/io/cdap/wrangler/directives/AggregateStatsDirective.java new file mode 100644 index 000000000..bbe39f176 --- /dev/null +++ b/wrangler-api/src/main/java/io/cdap/wrangler/directives/AggregateStatsDirective.java @@ -0,0 +1,156 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + + + +package io.cdap.wrangler.directives; + +import io.cdap.wrangler.api.Arguments; +import io.cdap.wrangler.api.Directive; +import io.cdap.wrangler.api.ExecutorContext; +import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.annotations.PublicEvolving; +import io.cdap.wrangler.api.parser.ColumnName; +import io.cdap.wrangler.api.parser.TokenType; +import io.cdap.wrangler.api.parser.UsageDefinition; + +import java.util.Collections; +import java.util.List; + +/** + * A directive for aggregating statistics about size and time duration columns. + */ +@PublicEvolving +public class AggregateStatsDirective implements Directive { + private String sizeColumn; + private String timeColumn; + private String totalSizeColumn; + private String totalTimeColumn; + + @Override + public UsageDefinition define() { + UsageDefinition.Builder builder = UsageDefinition.builder("aggregate-stats"); + builder.define("size-column", TokenType.COLUMN_NAME); + builder.define("time-column", TokenType.COLUMN_NAME); + builder.define("total-size-column", TokenType.COLUMN_NAME); + builder.define("total-time-column", TokenType.COLUMN_NAME); + return builder.build(); + } + + @Override + public void initialize(Arguments args) { + sizeColumn = ((ColumnName) args.value("size-column")).value(); + timeColumn = ((ColumnName) args.value("time-column")).value(); + totalSizeColumn = ((ColumnName) args.value("total-size-column")).value(); + totalTimeColumn = ((ColumnName) args.value("total-time-column")).value(); + } + + @Override + public List execute(List rows, ExecutorContext context) { + double totalSizeBytes = 0; + double totalTimeSeconds = 0; + + for (Row row : rows) { + try { + // Parse size value + String sizeValue = (String) row.getValue(sizeColumn); + if (sizeValue != null) { + totalSizeBytes += parseSize(sizeValue); + } + + // Parse time value + String timeValue = (String) row.getValue(timeColumn); + if (timeValue != null) { + totalTimeSeconds += parseTime(timeValue); + } + } catch (Exception e) { + // Skip invalid values + continue; + } + } + + // Create result row + Row result = new Row(); + result.add(totalSizeColumn, String.format("%.2f MB", totalSizeBytes / (1024 * 1024))); + result.add(totalTimeColumn, String.format("%.2f s", totalTimeSeconds)); + return Collections.singletonList(result); + } + + @Override + public void destroy() { + // No cleanup needed + } + + private double parseSize(String value) { + value = value.trim().toUpperCase(); + double number = Double.parseDouble(value.replaceAll("[^0-9.]", "")); + String unit = value.replaceAll("[0-9.]", "").trim(); + + switch (unit) { + case "B": + return number; + case "KB": + return number * 1024; + case "MB": + return number * 1024 * 1024; + case "GB": + return number * 1024 * 1024 * 1024; + case "TB": + return number * 1024 * 1024 * 1024 * 1024; + case "PB": + return number * 1024 * 1024 * 1024 * 1024 * 1024; + case "KIB": + return number * 1024; + case "MIB": + return number * 1024 * 1024; + case "GIB": + return number * 1024 * 1024 * 1024; + case "TIB": + return number * 1024 * 1024 * 1024 * 1024; + case "PIB": + return number * 1024 * 1024 * 1024 * 1024 * 1024; + default: + throw new IllegalArgumentException("Invalid size unit: " + unit); + } + } + + private double parseTime(String value) { + value = value.trim().toLowerCase(); + double number = Double.parseDouble(value.replaceAll("[^0-9.]", "")); + String unit = value.replaceAll("[0-9.]", "").trim(); + + switch (unit) { + case "ns": + return number / 1_000_000_000; + case "μs": + case "us": + return number / 1_000_000; + case "ms": + return number / 1000; + case "s": + return number; + case "m": + return number * 60; + case "h": + return number * 3600; + case "d": + return number * 86400; + default: + throw new IllegalArgumentException("Invalid time unit: " + unit); + } + } +} + diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/directives/one.txt b/wrangler-api/src/main/java/io/cdap/wrangler/directives/one.txt new file mode 100644 index 000000000..fb037e2ef --- /dev/null +++ b/wrangler-api/src/main/java/io/cdap/wrangler/directives/one.txt @@ -0,0 +1 @@ +//comments for updating git diff --git a/wrangler-api/src/test/java/io/cdap/wrangler/directives/AggregateStatsDirectiveTest.java b/wrangler-api/src/test/java/io/cdap/wrangler/directives/AggregateStatsDirectiveTest.java new file mode 100644 index 000000000..03dab571d --- /dev/null +++ b/wrangler-api/src/test/java/io/cdap/wrangler/directives/AggregateStatsDirectiveTest.java @@ -0,0 +1,312 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.directives; + +import io.cdap.wrangler.api.Arguments; +import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.parser.ColumnName; +import io.cdap.wrangler.api.parser.Token; +import io.cdap.wrangler.api.parser.TokenType; +import io.cdap.wrangler.api.parser.UsageDefinition; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +public class AggregateStatsDirectiveTest { + @Test + public void testDefine() { + AggregateStatsDirective directive = new AggregateStatsDirective(); + UsageDefinition definition = directive.define(); + + Assert.assertEquals("aggregate-stats", definition.getDirectiveName()); + Assert.assertEquals(4, definition.getTokens().size()); + Assert.assertEquals(TokenType.COLUMN_NAME, definition.getTokens().get(0).type()); + Assert.assertEquals(TokenType.COLUMN_NAME, definition.getTokens().get(1).type()); + Assert.assertEquals(TokenType.COLUMN_NAME, definition.getTokens().get(2).type()); + Assert.assertEquals(TokenType.COLUMN_NAME, definition.getTokens().get(3).type()); + } + + @Test + public void testSizeAndTimeCalculations() { + AggregateStatsDirective directive = new AggregateStatsDirective(); + + // Create test rows with various size and time units + List rows = new ArrayList<>(); + + // Row 1: 10MB and 100ms + Row row1 = new Row(); + row1.add("data_transfer_size", "10MB"); + row1.add("response_time", "100ms"); + rows.add(row1); + + // Row 2: 5MB and 200ms + Row row2 = new Row(); + row2.add("data_transfer_size", "5MB"); + row2.add("response_time", "200ms"); + rows.add(row2); + + // Row 3: 1GB and 1s + Row row3 = new Row(); + row3.add("data_transfer_size", "1GB"); + row3.add("response_time", "1s"); + rows.add(row3); + + // Initialize directive + directive.initialize(new Arguments() { + @Override + public T value(String name) { + switch (name) { + case "size-column": + return (T) new ColumnName("data_transfer_size"); + case "time-column": + return (T) new ColumnName("response_time"); + case "total-size-column": + return (T) new ColumnName("total_size_mb"); + case "total-time-column": + return (T) new ColumnName("total_time_sec"); + default: + return null; + } + } + + @Override + public int size() { + return 4; + } + + @Override + public boolean contains(String name) { + return true; + } + + @Override + public TokenType type(String name) { + return TokenType.COLUMN_NAME; + } + + @Override + public int line() { + return 0; + } + + @Override + public int column() { + return 0; + } + + @Override + public String source() { + return ""; + } + + @Override + public com.google.gson.JsonElement toJson() { + return null; + } + }); + + // Execute directive + List results = directive.execute(rows, null); + + // Verify results + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + + // Expected calculations: + // Size: 10MB + 5MB + 1GB = 10MB + 5MB + 1024MB = 1039MB + Assert.assertEquals("1039.00 MB", result.getValue("total_size_mb")); + + // Time: 100ms + 200ms + 1s = 0.1s + 0.2s + 1s = 1.3s + Assert.assertEquals("1.30 s", result.getValue("total_time_sec")); + } + + @Test + public void testMixedUnits() { + AggregateStatsDirective directive = new AggregateStatsDirective(); + + // Create test rows with mixed units + List rows = new ArrayList<>(); + + // Row 1: 1GB and 1s + Row row1 = new Row(); + row1.add("data_transfer_size", "1GB"); + row1.add("response_time", "1s"); + rows.add(row1); + + // Row 2: 1024KB and 1000ms + Row row2 = new Row(); + row2.add("data_transfer_size", "1024KB"); + row2.add("response_time", "1000ms"); + rows.add(row2); + + // Initialize directive + directive.initialize(new Arguments() { + @Override + public T value(String name) { + switch (name) { + case "size-column": + return (T) new ColumnName("data_transfer_size"); + case "time-column": + return (T) new ColumnName("response_time"); + case "total-size-column": + return (T) new ColumnName("total_size_mb"); + case "total-time-column": + return (T) new ColumnName("total_time_sec"); + default: + return null; + } + } + + @Override + public int size() { + return 4; + } + + @Override + public boolean contains(String name) { + return true; + } + + @Override + public TokenType type(String name) { + return TokenType.COLUMN_NAME; + } + + @Override + public int line() { + return 0; + } + + @Override + public int column() { + return 0; + } + + @Override + public String source() { + return ""; + } + + @Override + public com.google.gson.JsonElement toJson() { + return null; + } + }); + + // Execute directive + List results = directive.execute(rows, null); + + // Verify results + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + + // Expected calculations: + // Size: 1GB + 1024KB = 1024MB + 1MB = 1025MB + Assert.assertEquals("1025.00 MB", result.getValue("total_size_mb")); + + // Time: 1s + 1000ms = 1s + 1s = 2s + Assert.assertEquals("2.00 s", result.getValue("total_time_sec")); + } + + @Test + public void testInvalidValues() { + AggregateStatsDirective directive = new AggregateStatsDirective(); + + // Create test rows with some invalid values + List rows = new ArrayList<>(); + + // Row 1: Valid values + Row row1 = new Row(); + row1.add("data_transfer_size", "10MB"); + row1.add("response_time", "100ms"); + rows.add(row1); + + // Row 2: Invalid values + Row row2 = new Row(); + row2.add("data_transfer_size", "invalid"); + row2.add("response_time", "invalid"); + rows.add(row2); + + // Initialize directive + directive.initialize(new Arguments() { + @Override + public T value(String name) { + switch (name) { + case "size-column": + return (T) new ColumnName("data_transfer_size"); + case "time-column": + return (T) new ColumnName("response_time"); + case "total-size-column": + return (T) new ColumnName("total_size_mb"); + case "total-time-column": + return (T) new ColumnName("total_time_sec"); + default: + return null; + } + } + + @Override + public int size() { + return 4; + } + + @Override + public boolean contains(String name) { + return true; + } + + @Override + public TokenType type(String name) { + return TokenType.COLUMN_NAME; + } + + @Override + public int line() { + return 0; + } + + @Override + public int column() { + return 0; + } + + @Override + public String source() { + return ""; + } + + @Override + public com.google.gson.JsonElement toJson() { + return null; + } + }); + + // Execute directive + List results = directive.execute(rows, null); + + // Verify results + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + + // Only valid values should be counted + Assert.assertEquals("10.00 MB", result.getValue("total_size_mb")); + Assert.assertEquals("0.10 s", result.getValue("total_time_sec")); + } +} + diff --git a/wrangler-core/src/main/antlr4/io/cdap/wrangler/parser/Directives.g4 b/wrangler-core/src/main/antlr4/io/cdap/wrangler/parser/Directives.g4 index 7c517ed6a..561a7d4bf 100644 --- a/wrangler-core/src/main/antlr4/io/cdap/wrangler/parser/Directives.g4 +++ b/wrangler-core/src/main/antlr4/io/cdap/wrangler/parser/Directives.g4 @@ -140,7 +140,7 @@ numberRange ; value - : String | Number | Column | Bool + : String | Number | Column | Bool | BYTE_SIZE | TIME_DURATION ; ecommand @@ -311,3 +311,17 @@ fragment Int fragment Digit : [0-9] ; + +// Helper fragments for units +fragment BYTE_UNIT: 'B' | 'KB' | 'MB' | 'GB' | 'TB' | 'PB' | 'KiB' | 'MiB' | 'GiB' | 'TiB' | 'PiB'; +fragment TIME_UNIT: 'ns' | 'μs' | 'ms' | 's' | 'm' | 'h' | 'd'; + +// Lexer rules for size and time +WS : [ \t\r\n]+ -> skip; +BYTE_SIZE : Number WS* BYTE_UNIT; +TIME_DURATION : Number WS* TIME_UNIT; + +// Add fragment rules for common patterns +fragment DIGIT : [0-9]; +fragment LETTER : [a-zA-Z]; +fragment ESC : '\\' .; \ No newline at end of file diff --git a/wrangler-core/src/main/antlr4/io/cdap/wrangler/parser/Directives.g4.bak b/wrangler-core/src/main/antlr4/io/cdap/wrangler/parser/Directives.g4.bak new file mode 100644 index 000000000..4d77716fa --- /dev/null +++ b/wrangler-core/src/main/antlr4/io/cdap/wrangler/parser/Directives.g4.bak @@ -0,0 +1,333 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +grammar Directives; + +options { + language = Java; +} + +@lexer::header { +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +} + +/** + * Parser Grammar for recognizing tokens and constructs of the directives language. + */ +recipe + : statements EOF + ; + +statements + : ( Comment | macro | directive ';' | pragma ';' | ifStatement)* + ; + +directive + : command + ( codeblock + | identifier + | macro + | text + | number + | bool + | column + | colList + | numberList + | boolList + | stringList + | numberRanges + | properties + )*? + ; + +ifStatement + : ifStat elseIfStat* elseStat? '}' + ; + +ifStat + : 'if' expression '{' statements + ; + +elseIfStat + : '}' 'else' 'if' expression '{' statements + ; + +elseStat + : '}' 'else' '{' statements + ; + +expression + : '(' (~'(' | expression)* ')' + ; + +forStatement + : 'for' '(' Identifier '=' expression ';' expression ';' expression ')' '{' statements '}' + ; + +macro + : Dollar OBrace (~OBrace | macro | Macro)*? CBrace + ; + +pragma + : '#pragma' (pragmaLoadDirective | pragmaVersion) + ; + +pragmaLoadDirective + : 'load-directives' identifierList + ; + +pragmaVersion + : 'version' Number + ; + +codeblock + : 'exp' Space* ':' condition + ; + +identifier + : Identifier + ; + +properties + : 'prop' ':' OBrace (propertyList)+ CBrace + | 'prop' ':' OBrace OBrace (propertyList)+ CBrace { notifyErrorListeners("Too many start paranthesis"); } + | 'prop' ':' OBrace (propertyList)+ CBrace CBrace { notifyErrorListeners("Too many start paranthesis"); } + | 'prop' ':' (propertyList)+ CBrace { notifyErrorListeners("Missing opening brace"); } + | 'prop' ':' OBrace (propertyList)+ { notifyErrorListeners("Missing closing brace"); } + ; + +propertyList + : property (',' property)* + ; + +property + : Identifier '=' ( text | number | bool ) + ; + +numberRanges + : numberRange ( ',' numberRange)* + ; + +numberRange + : Number ':' Number '=' value + ; + +value + : String | Number | Column | Bool | BYTE_SIZE | TIME_DURATION + ; + +ecommand + : '!' Identifier + ; + +config + : Identifier + ; + +column + : Column + ; + +text + : String + ; + +number + : Number + ; + +bool + : Bool + ; + +condition + : OBrace (~CBrace | condition)* CBrace + ; + +command + : Identifier + ; + +colList + : Column (',' Column)+ + ; + +numberList + : Number (',' Number)+ + ; + +boolList + : Bool (',' Bool)+ + ; + +stringList + : String (',' String)+ + ; + +identifierList + : Identifier (',' Identifier)* + ; + + +/* + * Following are the Lexer Rules used for tokenizing the recipe. + */ +OBrace : '{'; +CBrace : '}'; +SColon : ';'; +Or : '||'; +And : '&&'; +Equals : '=='; +NEquals : '!='; +GTEquals : '>='; +LTEquals : '<='; +Match : '=~'; +NotMatch : '!~'; +QuestionColon : '?:'; +StartsWith : '=^'; +NotStartsWith : '!^'; +EndsWith : '=$'; +NotEndsWith : '!$'; +PlusEqual : '+='; +SubEqual : '-='; +MulEqual : '*='; +DivEqual : '/='; +PerEqual : '%='; +AndEqual : '&='; +OrEqual : '|='; +XOREqual : '^='; +Pow : '^'; +External : '!'; +GT : '>'; +LT : '<'; +Add : '+'; +Subtract : '-'; +Multiply : '*'; +Divide : '/'; +Modulus : '%'; +OBracket : '['; +CBracket : ']'; +OParen : '('; +CParen : ')'; +Assign : '='; +Comma : ','; +QMark : '?'; +Colon : ':'; +Dot : '.'; +At : '@'; +Pipe : '|'; +BackSlash: '\\'; +Dollar : '$'; +Tilde : '~'; + + +Bool + : 'true' + | 'false' + ; + +Number + : Int ('.' Digit*)? + ; + +Identifier + : [a-zA-Z_\-] [a-zA-Z_0-9\-]* + ; + +Macro + : [a-zA-Z_] [a-zA-Z_0-9]* + ; + +Column + : ':' [a-zA-Z_\-] [:a-zA-Z_0-9\-]* + ; + +String + : '\'' ( EscapeSequence | ~('\'') )* '\'' + | '"' ( EscapeSequence | ~('"') )* '"' + ; + +EscapeSequence + : '\\' ('b'|'t'|'n'|'f'|'r'|'"'|'\''|'\\') + | UnicodeEscape + | OctalEscape + ; + +fragment +OctalEscape + : '\\' ('0'..'3') ('0'..'7') ('0'..'7') + | '\\' ('0'..'7') ('0'..'7') + | '\\' ('0'..'7') + ; + +fragment +UnicodeEscape + : '\\' 'u' HexDigit HexDigit HexDigit HexDigit + ; + +fragment + HexDigit : ('0'..'9'|'a'..'f'|'A'..'F') ; + +Comment + : ('//' ~[\r\n]* | '/*' .*? '*/' | '--' ~[\r\n]* ) -> skip + ; + +Space + : [ \t\r\n\u000C]+ -> skip + ; + +fragment Int + : '-'? [1-9] Digit* [L]* + | '0' + ; + +fragment Digit + : [0-9] + ; + +// Helper fragments for units +fragment BYTE_UNIT: 'B' | 'KB' | 'MB' | 'GB' | 'TB' | 'PB' | 'KiB' | 'MiB' | 'GiB' | 'TiB' | 'PiB'; +fragment TIME_UNIT: 'ns' | 'μs' | 'ms' | 's' | 'm' | 'h' | 'd'; + +// Lexer rules for size and time +BYTE_SIZE: NUMBER WS* BYTE_UNIT; +TIME_DURATION: NUMBER WS* TIME_UNIT; + +// Essential lexer rules +WS : [ \t\r\n]+ -> skip; +Number : [0-9]+ ('.' [0-9]+)?; +String : '"' (~["\\] | '\\' .)* '"' | '\'' (~['\\] | '\\' .)* '\''; +Bool : 'true' | 'false'; +Column : '`' (~[`\\] | '\\' .)* '`'; +Identifier : [a-zA-Z_][a-zA-Z0-9_]*; +Comment : '//' ~[\r\n]* -> skip; +BYTE_SIZE : [0-9]+ ('B'|'KB'|'MB'|'GB'|'TB'); +TIME_DURATION : [0-9]+ ('s'|'m'|'h'|'d'); +Space : [ \t\r\n]+; diff --git a/wrangler-core/src/main/antlr4/io/cdap/wrangler/parser/grammar/Directives.g4 b/wrangler-core/src/main/antlr4/io/cdap/wrangler/parser/grammar/Directives.g4 new file mode 100644 index 000000000..b04c0ef03 --- /dev/null +++ b/wrangler-core/src/main/antlr4/io/cdap/wrangler/parser/grammar/Directives.g4 @@ -0,0 +1,53 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +grammar Directives; + +// Fragments +fragment DIGIT: [0-9]; +fragment WS: [ \t\r\n]+; +fragment LETTER: [a-zA-Z]; +fragment ESC: '\\' .; +fragment STRING_CHAR: ~["\\\r\n] | ESC; + +// Fragments for units +fragment BYTE_UNIT: 'B' | 'KB' | 'MB' | 'GB' | 'TB' | 'PB' | 'KiB' | 'MiB' | 'GiB' | 'TiB' | 'PiB'; +fragment TIME_UNIT: 'ns' | 'μs' | 'ms' | 's' | 'm' | 'h' | 'd'; + +// Lexer rules +STRING: '"' STRING_CHAR* '"' | '\'' STRING_CHAR* '\''; +NUMBER: DIGIT+ ('.' DIGIT+)?; +BOOLEAN: 'true' | 'false'; +NULL: 'null'; +IDENTIFIER: LETTER (LETTER | DIGIT | '_')*; + +// Lexer rules for size and time +BYTE_SIZE: DIGIT+ ('.' DIGIT+)? WS* BYTE_UNIT; +TIME_DURATION: DIGIT+ ('.' DIGIT+)? WS* TIME_UNIT; + +// Parser rules +byteSize: BYTE_SIZE; +timeDuration: TIME_DURATION; + +// Add to value rule +value: + STRING + | NUMBER + | BOOLEAN + | byteSize + | timeDuration + | NULL + ; \ No newline at end of file diff --git a/wrangler-core/src/main/java/io/cdap/directives/aggregates/AggregateStats.java b/wrangler-core/src/main/java/io/cdap/directives/aggregates/AggregateStats.java new file mode 100644 index 000000000..bc2c63667 --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/directives/aggregates/AggregateStats.java @@ -0,0 +1,387 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.directives.aggregates; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.cdap.cdap.api.annotation.Description; +import io.cdap.cdap.api.annotation.Name; +import io.cdap.cdap.api.annotation.Plugin; +import io.cdap.wrangler.api.Arguments; +import io.cdap.wrangler.api.Directive; +import io.cdap.wrangler.api.DirectiveExecutionException; +import io.cdap.wrangler.api.DirectiveParseException; +import io.cdap.wrangler.api.ErrorRowException; +import io.cdap.wrangler.api.ExecutorContext; +import io.cdap.wrangler.api.Optional; +import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.annotations.Categories; +import io.cdap.wrangler.api.lineage.Lineage; +import io.cdap.wrangler.api.lineage.Mutation; +import io.cdap.wrangler.api.parser.ColumnName; +import io.cdap.wrangler.api.parser.Identifier; +import io.cdap.wrangler.api.parser.TokenType; +import io.cdap.wrangler.api.parser.UsageDefinition; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Directive for aggregating statistics on byte size and time duration columns. + * + *

+ * This directive processes columns with size units (B, KB, MB, GB, etc.) or time units (s, m, h, d) + * and aggregates them into a single row with summary statistics. + *

+ */ +@Plugin(type = Directive.TYPE) +@Name(AggregateStats.NAME) +@Categories(categories = {"aggregate"}) +@Description("Aggregate statistics on byte size and time duration columns.") +public class AggregateStats implements Directive, Lineage { + public static final String NAME = "aggregate-stats"; + + // Size units constants + private static final String SIZE_TYPE = "SIZE"; + private static final String TIME_TYPE = "TIME"; + + // Patterns for matching byte sizes and time durations + private static final Pattern SIZE_PATTERN = Pattern.compile( + "(\\d+(\\.\\d+)?)\\s*(B|KB|MB|GB|TB|PB|KiB|MiB|GiB|TiB|PiB)?", + Pattern.CASE_INSENSITIVE); + private static final Pattern TIME_PATTERN = Pattern.compile( + "(\\d+(\\.\\d+)?)\\s*(ns|μs|ms|s|m|h|d)?", + Pattern.CASE_INSENSITIVE); + + // Maps to convert different units to a base unit (bytes or nanoseconds) + private static final Map SIZE_MULTIPLIERS = new HashMap<>(); + private static final Map BINARY_SIZE_MULTIPLIERS = new HashMap<>(); + private static final Map TIME_MULTIPLIERS = new HashMap<>(); + + static { + // Initialize size multipliers (decimal units) + SIZE_MULTIPLIERS.put("B", 1.0); + SIZE_MULTIPLIERS.put("KB", 1024.0); + SIZE_MULTIPLIERS.put("MB", 1024.0 * 1024.0); + SIZE_MULTIPLIERS.put("GB", 1024.0 * 1024.0 * 1024.0); + SIZE_MULTIPLIERS.put("TB", 1024.0 * 1024.0 * 1024.0 * 1024.0); + SIZE_MULTIPLIERS.put("PB", 1024.0 * 1024.0 * 1024.0 * 1024.0 * 1024.0); + + // Initialize binary size multipliers + BINARY_SIZE_MULTIPLIERS.put("B", 1.0); + BINARY_SIZE_MULTIPLIERS.put("KIB", 1024.0); + BINARY_SIZE_MULTIPLIERS.put("MIB", 1024.0 * 1024.0); + BINARY_SIZE_MULTIPLIERS.put("GIB", 1024.0 * 1024.0 * 1024.0); + BINARY_SIZE_MULTIPLIERS.put("TIB", 1024.0 * 1024.0 * 1024.0 * 1024.0); + BINARY_SIZE_MULTIPLIERS.put("PIB", 1024.0 * 1024.0 * 1024.0 * 1024.0 * 1024.0); + + // Initialize time multipliers (to nanoseconds) + TIME_MULTIPLIERS.put("NS", 1.0); + TIME_MULTIPLIERS.put("US", 1000.0); + TIME_MULTIPLIERS.put("MS", 1000000.0); + TIME_MULTIPLIERS.put("S", 1000000000.0); + TIME_MULTIPLIERS.put("M", 60.0 * 1000000000.0); + TIME_MULTIPLIERS.put("H", 60.0 * 60.0 * 1000000000.0); + TIME_MULTIPLIERS.put("D", 24.0 * 60.0 * 60.0 * 1000000000.0); + } + + // Column specifications + private String column; + private String type; + + @Override + public UsageDefinition define() { + UsageDefinition.Builder builder = UsageDefinition.builder(NAME); + builder.define("column", TokenType.COLUMN_NAME); + builder.define("type", TokenType.IDENTIFIER); + return builder.build(); + } + + @Override + public void initialize(Arguments args) throws DirectiveParseException { + this.column = ((ColumnName) args.value("column")).value(); + this.type = ((Identifier) args.value("type")).value(); + + if (!SIZE_TYPE.equals(type) && !TIME_TYPE.equals(type)) { + throw new DirectiveParseException( + NAME, "Invalid column type. Expected SIZE or TIME, but got " + type); + } + } + + @Override + public void destroy() { + // no-op + } + + @Override + public List execute(List rows, ExecutorContext context) + throws DirectiveExecutionException, ErrorRowException { + if (rows.isEmpty()) { + return rows; + } + + try { + if ("SIZE".equals(type)) { + return aggregateSizeValues(rows, column); + } else if ("TIME".equals(type)) { + return aggregateTimeValues(rows, column); + } else { + throw new DirectiveExecutionException( + "Invalid aggregation type. Use SIZE or TIME."); + } + } catch (Exception e) { + throw new DirectiveExecutionException( + "Failed to aggregate statistics: " + e.getMessage(), e); + } + } + + @Override + public Mutation lineage() { + return Mutation.builder() + .readable("Aggregates statistics for size and time columns") + .build(); + } + + /** + * Aggregate size values from the specified column. + * + * @param rows List of rows to process + * @param columnName Name of the column containing size values + * @return The processed rows with aggregated statistics + * @throws DirectiveExecutionException If there are issues during processing + */ + private List aggregateSizeValues(List rows, String columnName) + throws DirectiveExecutionException { + double sum = 0.0; + double min = Double.MAX_VALUE; + double max = Double.MIN_VALUE; + int count = 0; + String displayUnit = determineDisplayUnit(rows, columnName, true); + + for (Row row : rows) { + Object value = row.getValue(columnName); + if (value != null) { + double bytes = parseSize(value.toString()); + sum += bytes; + min = Math.min(min, bytes); + max = Math.max(max, bytes); + count++; + } + } + + if (count == 0) { + return ImmutableList.of(); + } + + double avg = sum / count; + + // Create a single row with aggregated statistics + Row result = new Row(); + double multiplier = getUnitMultiplier(displayUnit, true); + result.add("sum", formatValue(sum / multiplier) + " " + displayUnit); + result.add("avg", formatValue(avg / multiplier) + " " + displayUnit); + result.add("min", formatValue(min / multiplier) + " " + displayUnit); + result.add("max", formatValue(max / multiplier) + " " + displayUnit); + + return ImmutableList.of(result); + } + + /** + * Aggregate time values from the specified column. + * + * @param rows List of rows to process + * @param columnName Name of the column containing time values + * @return The processed rows with aggregated statistics + * @throws DirectiveExecutionException If there are issues during processing + */ + private List aggregateTimeValues(List rows, String columnName) + throws DirectiveExecutionException { + double sum = 0.0; + double min = Double.MAX_VALUE; + double max = Double.MIN_VALUE; + int count = 0; + String displayUnit = determineDisplayUnit(rows, columnName, false); + + for (Row row : rows) { + Object value = row.getValue(columnName); + if (value != null) { + double nanoseconds = parseTime(value.toString()); + sum += nanoseconds; + min = Math.min(min, nanoseconds); + max = Math.max(max, nanoseconds); + count++; + } + } + + if (count == 0) { + return ImmutableList.of(); + } + + double avg = sum / count; + + // Create a single row with aggregated statistics + Row result = new Row(); + double multiplier = getUnitMultiplier(displayUnit, false); + result.add("sum", formatValue(sum / multiplier) + " " + displayUnit); + result.add("avg", formatValue(avg / multiplier) + " " + displayUnit); + result.add("min", formatValue(min / multiplier) + " " + displayUnit); + result.add("max", formatValue(max / multiplier) + " " + displayUnit); + + return ImmutableList.of(result); + } + + /** + * Parse a size string into bytes. + * + * @param sizeStr The size string to parse (e.g., "1KB", "2.5MB") + * @return The size in bytes + * @throws DirectiveExecutionException If the size string is invalid + */ + private double parseSize(String sizeStr) throws DirectiveExecutionException { + Matcher matcher = SIZE_PATTERN.matcher(sizeStr.trim()); + if (!matcher.matches()) { + throw new DirectiveExecutionException( + "Invalid size format: " + sizeStr + ". Expected format: number[unit]"); + } + + double value = Double.parseDouble(matcher.group(1)); + String unit = matcher.group(3); + + if (unit == null || unit.isEmpty()) { + return value; + } + + String upperUnit = unit.toUpperCase(); + if (SIZE_MULTIPLIERS.containsKey(upperUnit)) { + return value * SIZE_MULTIPLIERS.get(upperUnit); + } else if (BINARY_SIZE_MULTIPLIERS.containsKey(upperUnit)) { + return value * BINARY_SIZE_MULTIPLIERS.get(upperUnit); + } else { + throw new DirectiveExecutionException("Invalid size unit: " + unit); + } + } + + /** + * Parse a time string into nanoseconds. + * + * @param timeStr The time string to parse (e.g., "1s", "2.5m") + * @return The time in nanoseconds + * @throws DirectiveExecutionException If the time string is invalid + */ + private double parseTime(String timeStr) throws DirectiveExecutionException { + Matcher matcher = TIME_PATTERN.matcher(timeStr.trim()); + if (!matcher.matches()) { + throw new DirectiveExecutionException( + "Invalid time format: " + timeStr + ". Expected format: number[unit]"); + } + + double value = Double.parseDouble(matcher.group(1)); + String unit = matcher.group(3); + + if (unit == null || unit.isEmpty()) { + return value; + } + + String upperUnit = unit.toUpperCase(); + if (TIME_MULTIPLIERS.containsKey(upperUnit)) { + return value * TIME_MULTIPLIERS.get(upperUnit); + } else { + throw new DirectiveExecutionException("Invalid time unit: " + unit); + } + } + + /** + * Determine the most appropriate display unit for the values. + * + * @param rows The rows containing the values + * @param columnName The name of the column + * @param isSize Whether the values are sizes (true) or times (false) + * @return The most appropriate display unit + */ + private String determineDisplayUnit(List rows, String columnName, boolean isSize) { + if (rows.isEmpty()) { + return isSize ? "B" : "s"; + } + + // Find the first non-null value + String firstValue = null; + for (Row row : rows) { + Object value = row.getValue(columnName); + if (value != null) { + firstValue = value.toString(); + break; + } + } + + if (firstValue == null) { + return isSize ? "B" : "s"; + } + + // Extract the unit from the first value + Matcher matcher = isSize ? SIZE_PATTERN.matcher(firstValue) : TIME_PATTERN.matcher(firstValue); + if (matcher.matches()) { + String unit = matcher.group(3); + if (unit != null && !unit.isEmpty()) { + return unit; + } + } + + return isSize ? "B" : "s"; + } + + /** + * Get the multiplier for converting from base unit to the specified unit. + * + * @param unit The target unit + * @param isSize Whether the unit is for size (true) or time (false) + * @return The multiplier to convert from base unit to the specified unit + */ + private double getUnitMultiplier(String unit, boolean isSize) { + String upperUnit = unit.toUpperCase(); + if (isSize) { + if (SIZE_MULTIPLIERS.containsKey(upperUnit)) { + return SIZE_MULTIPLIERS.get(upperUnit); + } else if (BINARY_SIZE_MULTIPLIERS.containsKey(upperUnit)) { + return BINARY_SIZE_MULTIPLIERS.get(upperUnit); + } + } else { + if (TIME_MULTIPLIERS.containsKey(upperUnit)) { + return TIME_MULTIPLIERS.get(upperUnit); + } + } + return 1.0; + } + + /** + * Format a numeric value to a string with 2 decimal places. + * + * @param value The value to format + * @return The formatted string + */ + private String formatValue(double value) { + return BigDecimal.valueOf(value) + .setScale(2, RoundingMode.HALF_UP) + .toString(); + } +} + diff --git a/wrangler-core/src/main/java/io/cdap/directives/aggregates/AggregateStatsDirective.java b/wrangler-core/src/main/java/io/cdap/directives/aggregates/AggregateStatsDirective.java new file mode 100644 index 000000000..410061b4a --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/directives/aggregates/AggregateStatsDirective.java @@ -0,0 +1,323 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + + package io.cdap.directives.aggregates; + + import com.google.common.collect.ImmutableMap; + import io.cdap.cdap.api.annotation.Description; + import io.cdap.cdap.api.annotation.Name; + import io.cdap.cdap.api.annotation.Plugin; + import io.cdap.wrangler.api.Arguments; + import io.cdap.wrangler.api.Directive; + import io.cdap.wrangler.api.DirectiveExecutionException; + import io.cdap.wrangler.api.DirectiveParseException; + import io.cdap.wrangler.api.ExecutorContext; + import io.cdap.wrangler.api.Row; + import io.cdap.wrangler.api.annotations.Categories; + import io.cdap.wrangler.api.lineage.Lineage; + import io.cdap.wrangler.api.lineage.Mutation; + import io.cdap.wrangler.api.parser.ColumnName; + import io.cdap.wrangler.api.parser.Identifier; + import io.cdap.wrangler.api.parser.Token; + import io.cdap.wrangler.api.parser.TokenType; + import io.cdap.wrangler.api.parser.UsageDefinition; + + import java.math.BigDecimal; + import java.math.RoundingMode; + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.regex.Matcher; + import java.util.regex.Pattern; + + /** + * A directive that aggregates statistics for size and time duration columns. + * It calculates sum, average, minimum, maximum, and returns a count value of 1 (representing + * the single aggregated row) for the specified column. + * For size columns, it supports units like B, KB, MB, GB, TB, PB, KiB, MiB, GiB, TiB, PiB. + * For time columns, it supports units like ns, μs, ms, s, m, h, d. + */ + @Plugin(type = Directive.TYPE) + @Name(AggregateStatsDirective.NAME) + @Categories(categories = {"aggregate"}) + @Description("Aggregate statistics on byte size and time duration columns.") + public class AggregateStatsDirective implements Directive, Lineage { + public static final String NAME = "aggregate-stats"; + + private static final Pattern SIZE_PATTERN = Pattern.compile( + "(\\d+(\\.\\d+)?)\\s*(B|KB|MB|GB|TB|PB|KiB|MiB|GiB|TiB|PiB)?", + Pattern.CASE_INSENSITIVE); + + private static final Pattern TIME_PATTERN = Pattern.compile( + "(\\d+(\\.\\d+)?)\\s*(ns|μs|ms|s|m|h|d)?", + Pattern.CASE_INSENSITIVE); + + private static final Map SIZE_MULTIPLIERS = ImmutableMap.builder() + .put("B", 1.0) + .put("KB", 1024.0) + .put("MB", Math.pow(1024, 2)) + .put("GB", Math.pow(1024, 3)) + .put("TB", Math.pow(1024, 4)) + .put("PB", Math.pow(1024, 5)) + .build(); + + private static final Map BINARY_SIZE_MULTIPLIERS = ImmutableMap.builder() + .put("KIB", 1024.0) + .put("MIB", Math.pow(1024, 2)) + .put("GIB", Math.pow(1024, 3)) + .put("TIB", Math.pow(1024, 4)) + .put("PIB", Math.pow(1024, 5)) + .build(); + + private static final Map TIME_MULTIPLIERS = ImmutableMap.builder() + .put("NS", 1.0) + .put("μS", 1000.0) + .put("MS", 1_000_000.0) + .put("S", 1_000_000_000.0) + .put("M", 60.0 * 1_000_000_000) + .put("H", 3600.0 * 1_000_000_000) + .put("D", 86400.0 * 1_000_000_000) + .build(); + + private String column; + private String type; + private String outputUnit; + private Arguments arguments; + + // Holds intermediate aggregation values. + private static class AggregationState { + double sum = 0; + double min = Double.MAX_VALUE; + double max = Double.MIN_VALUE; + int count = 0; + } + + @Override + public UsageDefinition define() { + UsageDefinition.Builder builder = UsageDefinition.builder(NAME); + builder.define("column", TokenType.COLUMN_NAME); + builder.define("type", TokenType.IDENTIFIER); + builder.define("output_unit", TokenType.IDENTIFIER, true); // Optional output unit + return builder.build(); + } + + @Override + public void initialize(Arguments args) throws DirectiveParseException { + this.arguments = args; + this.column = ((ColumnName) args.value("column")).value(); + this.type = ((Identifier) args.value("type")).value(); + this.outputUnit = args.contains("output_unit") ? + ((Identifier) args.value("output_unit")).value() : null; + + if (!type.equalsIgnoreCase("SIZE") && !type.equalsIgnoreCase("DURATION")) { + throw new DirectiveParseException(NAME, "Invalid type. Expected SIZE or DURATION."); + } + + if (outputUnit != null) { + if (type.equalsIgnoreCase("SIZE")) { + if (!isValidSizeUnit(outputUnit)) { + throw new DirectiveParseException(NAME, "Invalid size unit: " + outputUnit); + } + } else { + if (!isValidTimeUnit(outputUnit)) { + throw new DirectiveParseException(NAME, "Invalid time unit: " + outputUnit); + } + } + } + } + + private boolean isValidSizeUnit(String unit) { + String upperUnit = unit.toUpperCase(); + return SIZE_MULTIPLIERS.containsKey(upperUnit) || BINARY_SIZE_MULTIPLIERS.containsKey(upperUnit); + } + + private boolean isValidTimeUnit(String unit) { + return TIME_MULTIPLIERS.containsKey(unit.toUpperCase()); + } + + @Override + public List execute(List rows, ExecutorContext context) throws DirectiveExecutionException { + AggregationState state = new AggregationState(); + String typeLower = type.toLowerCase(); + + // For each row, try to parse the value and aggregate. + for (Row row : rows) { + Object val = row.getValue(column); + if (val == null) { + continue; + } + + String strVal = val.toString(); + double parsed; + try { + if (typeLower.equals("size")) { + parsed = parseSize(strVal); + } else { + parsed = parseTime(strVal); + } + state.sum += parsed; + state.min = Math.min(state.min, parsed); + state.max = Math.max(state.max, parsed); + state.count++; + } catch (IllegalArgumentException e) { + throw new DirectiveExecutionException(NAME, e.getMessage()); + } + } + + // Create a single aggregate row. + Row result = new Row(); + if (state.count == 0) { + // When no valid values found, use default values. + if (typeLower.equals("size")) { + result.add("sum", formatSize(0, outputUnit)); + result.add("avg", formatSize(0, outputUnit)); + result.add("min", formatSize(0, outputUnit)); + result.add("max", formatSize(0, outputUnit)); + } else { + result.add("sum", formatTime(0, outputUnit)); + result.add("avg", formatTime(0, outputUnit)); + result.add("min", formatTime(0, outputUnit)); + result.add("max", formatTime(0, outputUnit)); + } + } else { + double avg = state.sum / state.count; + if (typeLower.equals("size")) { + result.add("sum", formatSize(state.sum, outputUnit)); + result.add("avg", formatSize(avg, outputUnit)); + result.add("min", formatSize(state.min, outputUnit)); + result.add("max", formatSize(state.max, outputUnit)); + } else { + result.add("sum", formatTime(state.sum, outputUnit)); + result.add("avg", formatTime(avg, outputUnit)); + result.add("min", formatTime(state.min, outputUnit)); + result.add("max", formatTime(state.max, outputUnit)); + } + } + + // Regardless of how many rows were aggregated, always set count to 1. + result.add("count", 1); + return Collections.singletonList(result); + } + + /** + * Parses a size string (e.g., "1KB", "2.5MB") into bytes. + * Throws IllegalArgumentException if the format or unit is invalid. + */ + private double parseSize(String value) { + Matcher matcher = SIZE_PATTERN.matcher(value); + if (!matcher.matches()) { + throw new IllegalArgumentException("Invalid size format: " + value); + } + + double number = Double.parseDouble(matcher.group(1)); + String unit = (matcher.group(3) != null) ? matcher.group(3).toUpperCase() : "B"; + + Double multiplier = SIZE_MULTIPLIERS.get(unit); + if (multiplier == null) { + multiplier = BINARY_SIZE_MULTIPLIERS.get(unit); + if (multiplier == null) { + throw new IllegalArgumentException("Invalid size unit: " + unit); + } + } + return number * multiplier; + } + + /** + * Parses a time string (e.g., "1s", "2.5m") into nanoseconds. + * Throws IllegalArgumentException if the format or unit is invalid. + */ + private double parseTime(String value) { + Matcher matcher = TIME_PATTERN.matcher(value); + if (!matcher.matches()) { + throw new IllegalArgumentException("Invalid time format: " + value); + } + + double number = Double.parseDouble(matcher.group(1)); + String unit = (matcher.group(3) != null) ? matcher.group(3).toUpperCase() : "S"; + + Double multiplier = TIME_MULTIPLIERS.get(unit); + if (multiplier == null) { + throw new IllegalArgumentException("Invalid time unit: " + unit); + } + return number * multiplier; + } + + /** + * Converts a byte value into a human-readable string (e.g., "1.23MB"). + */ + private String formatSize(double value, String outputUnit) { + if (outputUnit != null) { + double multiplier = getSizeMultiplier(outputUnit); + return String.format("%.2f%s", value / multiplier, outputUnit); + } + + String[] units = {"B", "KB", "MB", "GB", "TB", "PB"}; + int idx = 0; + while (value >= 1024 && idx < units.length - 1) { + value /= 1024; + idx++; + } + return String.format("%.2f%s", value, units[idx]); + } + + /** + * Converts a time value in nanoseconds to a human-readable string (e.g., "1.23s"). + */ + private String formatTime(double value, String outputUnit) { + if (outputUnit != null) { + double multiplier = getTimeMultiplier(outputUnit); + return String.format("%.2f%s", value / multiplier, outputUnit); + } + + String[] units = {"ns", "μs", "ms", "s", "m", "h", "d"}; + double[] multipliers = {1, 1_000.0, 1_000_000.0, 1_000_000_000.0, + 60.0 * 1_000_000_000.0, 3600.0 * 1_000_000_000.0, + 86400.0 * 1_000_000_000.0}; + + int idx = units.length - 1; + while (idx > 0 && value < multipliers[idx]) { + idx--; + } + return String.format("%.2f%s", value / multipliers[idx], units[idx]); + } + + private double getSizeMultiplier(String unit) { + String upperUnit = unit.toUpperCase(); + Double multiplier = SIZE_MULTIPLIERS.get(upperUnit); + if (multiplier == null) { + multiplier = BINARY_SIZE_MULTIPLIERS.get(upperUnit); + } + return multiplier; + } + + private double getTimeMultiplier(String unit) { + return TIME_MULTIPLIERS.get(unit.toUpperCase()); + } + + @Override + public void destroy() { + // No cleanup needed. + } + + @Override + public Mutation lineage() { + return Mutation.builder() + .readable("Aggregates statistics for size and time columns") + .build(); + } + } + + diff --git a/wrangler-core/src/main/java/io/cdap/functions/JsonFunctions.java b/wrangler-core/src/main/java/io/cdap/functions/JsonFunctions.java index d07f24cfb..da3b81d02 100644 --- a/wrangler-core/src/main/java/io/cdap/functions/JsonFunctions.java +++ b/wrangler-core/src/main/java/io/cdap/functions/JsonFunctions.java @@ -324,7 +324,6 @@ public static String Stringify(JsonElement element) { /** * @return Number of elements in the array. */ - @Nullable public static int ArrayLength(JsonArray array) { if (array != null) { return array.size(); diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/directives/AggregateStatsDirective.java b/wrangler-core/src/main/java/io/cdap/wrangler/directives/AggregateStatsDirective.java new file mode 100644 index 000000000..fbdab4047 --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/wrangler/directives/AggregateStatsDirective.java @@ -0,0 +1,108 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.directives; + +import io.cdap.wrangler.api.Arguments; +import io.cdap.wrangler.api.Directive; +import io.cdap.wrangler.api.ExecutorContext; +import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.annotations.Categories; +import io.cdap.wrangler.api.parser.ByteSize; +import io.cdap.wrangler.api.parser.ColumnName; +import io.cdap.wrangler.api.parser.TimeDuration; +import io.cdap.wrangler.api.parser.TokenType; +import io.cdap.wrangler.api.parser.UsageDefinition; + +import java.util.ArrayList; +import java.util.List; + +/** + * A directive for aggregating byte sizes and time durations across rows. + */ +@Categories(categories = { "transform" }) +public class AggregateStatsDirective implements Directive { + private ColumnName sizeColumn; + private ColumnName timeColumn; + private ColumnName totalSizeColumn; + private ColumnName totalTimeColumn; + private long totalBytes = 0; + private long totalNanoseconds = 0; + private int rowCount = 0; + + @Override + public UsageDefinition define() { + UsageDefinition.Builder builder = UsageDefinition.builder("aggregate-stats"); + builder.define("size-column", TokenType.COLUMN_NAME); + builder.define("time-column", TokenType.COLUMN_NAME); + builder.define("total-size-column", TokenType.COLUMN_NAME); + builder.define("total-time-column", TokenType.COLUMN_NAME); + return builder.build(); + } + + @Override + public void initialize(Arguments args) { + sizeColumn = args.value("size-column"); + timeColumn = args.value("time-column"); + totalSizeColumn = args.value("total-size-column"); + totalTimeColumn = args.value("total-time-column"); + } + + @Override + public List execute(List rows, ExecutorContext context) { + for (Row row : rows) { + // Get size value and convert to bytes + Object sizeValue = row.getValue(sizeColumn.value()); + if (sizeValue != null) { + String sizeStr = sizeValue.toString(); + try { + ByteSize byteSize = new ByteSize(sizeStr); + totalBytes += byteSize.getBytes(); + } catch (IllegalArgumentException e) { + // Skip invalid byte size values + } + } + + // Get time value and convert to nanoseconds + Object timeValue = row.getValue(timeColumn.value()); + if (timeValue != null) { + String timeStr = timeValue.toString(); + try { + TimeDuration timeDuration = new TimeDuration(timeStr); + totalNanoseconds += timeDuration.getNanoseconds(); + } catch (IllegalArgumentException e) { + // Skip invalid time duration values + } + } + + rowCount++; + } + + // Create a new row with the aggregated values + Row result = new Row(); + result.add(totalSizeColumn.value(), String.format("%.2f MB", totalBytes / (1024.0 * 1024))); + result.add(totalTimeColumn.value(), String.format("%.2f s", totalNanoseconds / (1000.0 * 1000 * 1000))); + + List results = new ArrayList<>(); + results.add(result); + return results; + } + + @Override + public void destroy() { + // Clean up any resources if needed + } +} \ No newline at end of file diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/dq/ConvertDistances.java b/wrangler-core/src/main/java/io/cdap/wrangler/dq/ConvertDistances.java index 1be87b116..17534fd34 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/dq/ConvertDistances.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/dq/ConvertDistances.java @@ -103,7 +103,7 @@ public ConvertDistances() { this(Distance.MILE, Distance.KILOMETER); } - @Nullable + public ConvertDistances(Distance from, Distance to) { this.from = (from == null ? Distance.MILE : from); this.to = (to == null ? Distance.KILOMETER : to); diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/expression/ELContext.java b/wrangler-core/src/main/java/io/cdap/wrangler/expression/ELContext.java index 04b0b884b..75a1a86b2 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/expression/ELContext.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/expression/ELContext.java @@ -91,7 +91,7 @@ public ELContext(ExecutorContext context, EL el, Row row) { set("this", row); } - @Nullable + private void init(ExecutorContext context) { if (context != null) { // Adds the transient store variables. diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/parser/RecipeParserVisitor.java b/wrangler-core/src/main/java/io/cdap/wrangler/parser/RecipeParserVisitor.java new file mode 100644 index 000000000..a0dd1a8b6 --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/wrangler/parser/RecipeParserVisitor.java @@ -0,0 +1,36 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.parser; + +import io.cdap.wrangler.api.parser.Token; +import io.cdap.wrangler.api.parser.ByteSize; +import io.cdap.wrangler.api.parser.TimeDuration; +import io.cdap.wrangler.parser.grammar.DirectivesParser; +import io.cdap.wrangler.parser.grammar.DirectivesBaseVisitor; + +public class RecipeParserVisitor extends DirectivesBaseVisitor { + + @Override + public Token visitByteSize(DirectivesParser.ByteSizeContext ctx) { + return new ByteSize(ctx.getText()); + } + + @Override + public Token visitTimeDuration(DirectivesParser.TimeDurationContext ctx) { + return new TimeDuration(ctx.getText()); + } +} \ No newline at end of file diff --git a/wrangler-core/src/test/java/io/cdap/directives/aggregates/AggregateStatsDirectiveTest.java b/wrangler-core/src/test/java/io/cdap/directives/aggregates/AggregateStatsDirectiveTest.java new file mode 100644 index 000000000..73ab4c3fd --- /dev/null +++ b/wrangler-core/src/test/java/io/cdap/directives/aggregates/AggregateStatsDirectiveTest.java @@ -0,0 +1,225 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + + +package io.cdap.directives.aggregates; + +import io.cdap.wrangler.TestingRig; +import io.cdap.wrangler.api.DirectiveExecutionException; +import io.cdap.wrangler.api.DirectiveParseException; +import io.cdap.wrangler.api.Row; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +/** + * Tests for {@link AggregateStats} directive. + */ +public class AggregateStatsDirectiveTest { + + @Test + public void testBasicSizeAggregation() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("size", "1KB")); + rows.add(createRow("size", "2KB")); + rows.add(createRow("size", "3KB")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("6.00KB", result.getValue("sum")); + Assert.assertEquals("2.00KB", result.getValue("avg")); + Assert.assertEquals("1.00KB", result.getValue("min")); + Assert.assertEquals("3.00KB", result.getValue("max")); + Assert.assertEquals(3, result.getValue("count")); + } + + @Test + public void testBasicTimeAggregation() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("duration", "1s")); + rows.add(createRow("duration", "2s")); + rows.add(createRow("duration", "3s")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :duration DURATION"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("6.00s", result.getValue("sum")); + Assert.assertEquals("2.00s", result.getValue("avg")); + Assert.assertEquals("1.00s", result.getValue("min")); + Assert.assertEquals("3.00s", result.getValue("max")); + Assert.assertEquals(3, result.getValue("count")); + } + + @Test + public void testDifferentSizeUnits() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("size", "1MB")); + rows.add(createRow("size", "1024KB")); + rows.add(createRow("size", "1048576B")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("3.00MB", result.getValue("sum")); + Assert.assertEquals("1.00MB", result.getValue("avg")); + Assert.assertEquals("1.00MB", result.getValue("min")); + Assert.assertEquals("1.00MB", result.getValue("max")); + Assert.assertEquals(3, result.getValue("count")); + } + + @Test + public void testDifferentTimeUnits() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("duration", "1h")); + rows.add(createRow("duration", "60m")); + rows.add(createRow("duration", "3600s")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :duration DURATION"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("3.00h", result.getValue("sum")); + Assert.assertEquals("1.00h", result.getValue("avg")); + Assert.assertEquals("1.00h", result.getValue("min")); + Assert.assertEquals("1.00h", result.getValue("max")); + Assert.assertEquals(3, result.getValue("count")); + } + + @Test + public void testBinaryUnits() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("size", "1KiB")); + rows.add(createRow("size", "1MiB")); + rows.add(createRow("size", "1GiB")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertNotNull(result.getValue("sum")); + Assert.assertNotNull(result.getValue("avg")); + Assert.assertNotNull(result.getValue("min")); + Assert.assertNotNull(result.getValue("max")); + Assert.assertEquals(3, result.getValue("count")); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidSizeUnit() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("size", "10XB")); + + TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidTimeUnit() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("duration", "10x")); + + TestingRig.execute(new String[]{"aggregate-stats :duration DURATION"}, rows); + } + + @Test + public void testEmptyRows() throws Exception { + List rows = new ArrayList<>(); + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("0.00B", result.getValue("sum")); + Assert.assertEquals("0.00B", result.getValue("avg")); + Assert.assertEquals("0.00B", result.getValue("min")); + Assert.assertEquals("0.00B", result.getValue("max")); + Assert.assertEquals(0, result.getValue("count")); + } + + @Test + public void testNullValues() throws Exception { + List rows = new ArrayList<>(); + Row row1 = new Row(); + row1.add("size", null); + rows.add(row1); + + Row row2 = new Row(); + row2.add("size", "10MB"); + rows.add(row2); + + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("10.00MB", result.getValue("sum")); + Assert.assertEquals("10.00MB", result.getValue("avg")); + Assert.assertEquals("10.00MB", result.getValue("min")); + Assert.assertEquals("10.00MB", result.getValue("max")); + Assert.assertEquals(1, result.getValue("count")); + } + + @Test + public void testSizeWithOutputUnit() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("size", "1KB")); + rows.add(createRow("size", "2KB")); + rows.add(createRow("size", "3KB")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE MB"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("0.01MB", result.getValue("sum")); + Assert.assertEquals("0.00MB", result.getValue("avg")); + Assert.assertEquals("0.00MB", result.getValue("min")); + Assert.assertEquals("0.00MB", result.getValue("max")); + Assert.assertEquals(3, result.getValue("count")); + } + + @Test + public void testTimeWithOutputUnit() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("duration", "1s")); + rows.add(createRow("duration", "2s")); + rows.add(createRow("duration", "3s")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :duration DURATION m"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("0.10m", result.getValue("sum")); + Assert.assertEquals("0.03m", result.getValue("avg")); + Assert.assertEquals("0.02m", result.getValue("min")); + Assert.assertEquals("0.05m", result.getValue("max")); + Assert.assertEquals(3, result.getValue("count")); + } + + @Test(expected = DirectiveParseException.class) + public void testInvalidOutputUnit() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("size", "1KB")); + + TestingRig.execute(new String[]{"aggregate-stats :size SIZE XB"}, rows); + } + + private Row createRow(String column, String value) { + Row row = new Row(); + row.add(column, value); + return row; + } +} diff --git a/wrangler-core/src/test/java/io/cdap/directives/aggregates/AggregateStatsRegistrationTest.java b/wrangler-core/src/test/java/io/cdap/directives/aggregates/AggregateStatsRegistrationTest.java new file mode 100644 index 000000000..729e09ee9 --- /dev/null +++ b/wrangler-core/src/test/java/io/cdap/directives/aggregates/AggregateStatsRegistrationTest.java @@ -0,0 +1,61 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.directives.aggregates; + +import io.cdap.wrangler.registry.DirectiveInfo; +import io.cdap.wrangler.registry.SystemDirectiveRegistry; +import org.junit.Assert; +import org.junit.Test; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +/** + * Tests for {@link AggregateStatsDirective} directive registration. + */ +public class AggregateStatsRegistrationTest { + + @Test + public void testDirectiveRegistration() { + // Get the system directive registry + SystemDirectiveRegistry registry = SystemDirectiveRegistry.INSTANCE; + + // Get all registered directives + Iterable directives = registry.list("aggregate"); + + // Convert to list for easier processing + List directiveList = StreamSupport.stream(directives.spliterator(), false) + .collect(Collectors.toList()); + + // Find the aggregate-stats directive + DirectiveInfo aggregateStatsInfo = directiveList.stream() + .filter(info -> "aggregate-stats".equals(info.name())) + .findFirst() + .orElse(null); + + // Verify the directive is registered + Assert.assertNotNull("aggregate-stats directive should be registered", aggregateStatsInfo); + + // Verify the directive name + Assert.assertEquals("aggregate-stats", aggregateStatsInfo.name()); + + // Verify the directive class + Assert.assertEquals(AggregateStatsDirective.class.getName(), + aggregateStatsInfo.getDirectiveClass().getClassName()); + } +} diff --git a/wrangler-core/src/test/java/io/cdap/directives/aggregates/AggregateStatsTest.java b/wrangler-core/src/test/java/io/cdap/directives/aggregates/AggregateStatsTest.java new file mode 100644 index 000000000..6bb045c7a --- /dev/null +++ b/wrangler-core/src/test/java/io/cdap/directives/aggregates/AggregateStatsTest.java @@ -0,0 +1,171 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.directives.aggregates; + +import io.cdap.wrangler.TestingRig; +import io.cdap.wrangler.api.Row; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +/** + * Tests for {@link AggregateStats} implementation. + */ + +public class AggregateStatsTest { + + @Test + public void testBasicSizeAggregation() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("size", "1KB")); + rows.add(createRow("size", "2KB")); + rows.add(createRow("size", "3KB")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("3KB", result.getValue("sum")); + Assert.assertEquals("2KB", result.getValue("avg")); + Assert.assertEquals("1KB", result.getValue("min")); + Assert.assertEquals("3KB", result.getValue("max")); + } + + @Test + public void testBasicTimeAggregation() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("duration", "1s")); + rows.add(createRow("duration", "2s")); + rows.add(createRow("duration", "3s")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :duration TIME"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("6s", result.getValue("sum")); + Assert.assertEquals("2s", result.getValue("avg")); + Assert.assertEquals("1s", result.getValue("min")); + Assert.assertEquals("3s", result.getValue("max")); + } + + @Test + public void testDifferentSizeUnits() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("size", "1MB")); + rows.add(createRow("size", "1024KB")); + rows.add(createRow("size", "1048576B")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("3MB", result.getValue("sum")); + Assert.assertEquals("1MB", result.getValue("avg")); + Assert.assertEquals("1MB", result.getValue("min")); + Assert.assertEquals("1MB", result.getValue("max")); + } + + @Test + public void testDifferentTimeUnits() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("duration", "1h")); + rows.add(createRow("duration", "60m")); + rows.add(createRow("duration", "3600s")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :duration TIME"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("3h", result.getValue("sum")); + Assert.assertEquals("1h", result.getValue("avg")); + Assert.assertEquals("1h", result.getValue("min")); + Assert.assertEquals("1h", result.getValue("max")); + } + + @Test + public void testBinaryUnits() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("size", "1KiB")); + rows.add(createRow("size", "1MiB")); + rows.add(createRow("size", "1GiB")); + + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertNotNull(result.getValue("sum")); + Assert.assertNotNull(result.getValue("avg")); + Assert.assertNotNull(result.getValue("min")); + Assert.assertNotNull(result.getValue("max")); + } + + @Test(expected = Exception.class) + public void testInvalidSizeUnit() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("size", "10XB")); + + TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + } + + @Test(expected = Exception.class) + public void testInvalidTimeUnit() throws Exception { + List rows = new ArrayList<>(); + rows.add(createRow("duration", "10x")); + + TestingRig.execute(new String[]{"aggregate-stats :duration TIME"}, rows); + } + + @Test + public void testEmptyRows() throws Exception { + List rows = new ArrayList<>(); + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + Assert.assertEquals(0, results.size()); + } + + @Test + public void testNullValues() throws Exception { + List rows = new ArrayList<>(); + Row row1 = new Row(); + row1.add("size", null); + rows.add(row1); + + Row row2 = new Row(); + row2.add("size", "10MB"); + rows.add(row2); + + List results = TestingRig.execute(new String[]{"aggregate-stats :size SIZE"}, rows); + + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + Assert.assertEquals("10MB", result.getValue("sum")); + Assert.assertEquals("10MB", result.getValue("avg")); + Assert.assertEquals("10MB", result.getValue("min")); + Assert.assertEquals("10MB", result.getValue("max")); + } + + /** + * Helper method to create a row with a single column. + */ + private Row createRow(String column, String value) { + Row row = new Row(); + row.add(column, value); + return row; + } + +} diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/directives/AggregateStatsDirectiveTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/directives/AggregateStatsDirectiveTest.java new file mode 100644 index 000000000..d39177499 --- /dev/null +++ b/wrangler-core/src/test/java/io/cdap/wrangler/directives/AggregateStatsDirectiveTest.java @@ -0,0 +1,312 @@ +/* + * Copyright © 2017-2019 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.directives; + +import io.cdap.wrangler.api.Arguments; +import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.parser.ColumnName; +import io.cdap.wrangler.api.parser.Token; +import io.cdap.wrangler.api.parser.TokenType; +import io.cdap.wrangler.api.parser.UsageDefinition; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +public class AggregateStatsDirectiveTest { + @Test + public void testDefine() { + AggregateStatsDirective directive = new AggregateStatsDirective(); + UsageDefinition definition = directive.define(); + + Assert.assertEquals("aggregate-stats", definition.getDirectiveName()); + Assert.assertEquals(4, definition.getTokens().size()); + Assert.assertEquals(TokenType.COLUMN_NAME, definition.getTokens().get(0).type()); + Assert.assertEquals(TokenType.COLUMN_NAME, definition.getTokens().get(1).type()); + Assert.assertEquals(TokenType.COLUMN_NAME, definition.getTokens().get(2).type()); + Assert.assertEquals(TokenType.COLUMN_NAME, definition.getTokens().get(3).type()); + } + + @Test + public void testSizeAndTimeCalculations() { + AggregateStatsDirective directive = new AggregateStatsDirective(); + + // Create test rows with various size and time units + List rows = new ArrayList<>(); + + // Row 1: 10MB and 100ms + Row row1 = new Row(); + row1.add("data_transfer_size", "10MB"); + row1.add("response_time", "100ms"); + rows.add(row1); + + // Row 2: 5MB and 200ms + Row row2 = new Row(); + row2.add("data_transfer_size", "5MB"); + row2.add("response_time", "200ms"); + rows.add(row2); + + // Row 3: 1GB and 1s + Row row3 = new Row(); + row3.add("data_transfer_size", "1GB"); + row3.add("response_time", "1s"); + rows.add(row3); + + // Initialize directive + directive.initialize(new Arguments() { + @Override + public T value(String name) { + switch (name) { + case "size-column": + return (T) new ColumnName("data_transfer_size"); + case "time-column": + return (T) new ColumnName("response_time"); + case "total-size-column": + return (T) new ColumnName("total_size_mb"); + case "total-time-column": + return (T) new ColumnName("total_time_sec"); + default: + return null; + } + } + + @Override + public int size() { + return 4; + } + + @Override + public boolean contains(String name) { + return true; + } + + @Override + public TokenType type(String name) { + return TokenType.COLUMN_NAME; + } + + @Override + public int line() { + return 0; + } + + @Override + public int column() { + return 0; + } + + @Override + public String source() { + return ""; + } + + @Override + public com.google.gson.JsonElement toJson() { + return null; + } + }); + + // Execute directive + List results = directive.execute(rows, null); + + // Verify results + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + + // Expected calculations: + // Size: 10MB + 5MB + 1GB = 10MB + 5MB + 1024MB = 1039MB + Assert.assertEquals("1039.00 MB", result.getValue("total_size_mb")); + + // Time: 100ms + 200ms + 1s = 0.1s + 0.2s + 1s = 1.3s + Assert.assertEquals("1.30 s", result.getValue("total_time_sec")); + } + + @Test + public void testMixedUnits() { + AggregateStatsDirective directive = new AggregateStatsDirective(); + + // Create test rows with mixed units + List rows = new ArrayList<>(); + + // Row 1: 1GB and 1s + Row row1 = new Row(); + row1.add("data_transfer_size", "1GB"); + row1.add("response_time", "1s"); + rows.add(row1); + + // Row 2: 1024KB and 1000ms + Row row2 = new Row(); + row2.add("data_transfer_size", "1024KB"); + row2.add("response_time", "1000ms"); + rows.add(row2); + + // Initialize directive + directive.initialize(new Arguments() { + @Override + public T value(String name) { + switch (name) { + case "size-column": + return (T) new ColumnName("data_transfer_size"); + case "time-column": + return (T) new ColumnName("response_time"); + case "total-size-column": + return (T) new ColumnName("total_size_mb"); + case "total-time-column": + return (T) new ColumnName("total_time_sec"); + default: + return null; + } + } + + @Override + public int size() { + return 4; + } + + @Override + public boolean contains(String name) { + return true; + } + + @Override + public TokenType type(String name) { + return TokenType.COLUMN_NAME; + } + + @Override + public int line() { + return 0; + } + + @Override + public int column() { + return 0; + } + + @Override + public String source() { + return ""; + } + + @Override + public com.google.gson.JsonElement toJson() { + return null; + } + }); + + // Execute directive + List results = directive.execute(rows, null); + + // Verify results + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + + // Expected calculations: + // Size: 1GB + 1024KB = 1024MB + 1MB = 1025MB + Assert.assertEquals("1025.00 MB", result.getValue("total_size_mb")); + + // Time: 1s + 1000ms = 1s + 1s = 2s + Assert.assertEquals("2.00 s", result.getValue("total_time_sec")); + } + + @Test + public void testInvalidValues() { + AggregateStatsDirective directive = new AggregateStatsDirective(); + + // Create test rows with some invalid values + List rows = new ArrayList<>(); + + // Row 1: Valid values + Row row1 = new Row(); + row1.add("data_transfer_size", "10MB"); + row1.add("response_time", "100ms"); + rows.add(row1); + + // Row 2: Invalid values + Row row2 = new Row(); + row2.add("data_transfer_size", "invalid"); + row2.add("response_time", "invalid"); + rows.add(row2); + + // Initialize directive + directive.initialize(new Arguments() { + @Override + public T value(String name) { + switch (name) { + case "size-column": + return (T) new ColumnName("data_transfer_size"); + case "time-column": + return (T) new ColumnName("response_time"); + case "total-size-column": + return (T) new ColumnName("total_size_mb"); + case "total-time-column": + return (T) new ColumnName("total_time_sec"); + default: + return null; + } + } + + @Override + public int size() { + return 4; + } + + @Override + public boolean contains(String name) { + return true; + } + + @Override + public TokenType type(String name) { + return TokenType.COLUMN_NAME; + } + + @Override + public int line() { + return 0; + } + + @Override + public int column() { + return 0; + } + + @Override + public String source() { + return ""; + } + + @Override + public com.google.gson.JsonElement toJson() { + return null; + } + }); + + // Execute directive + List results = directive.execute(rows, null); + + // Verify results - should only count valid values + Assert.assertEquals(1, results.size()); + Row result = results.get(0); + + // Only the valid values should be counted + Assert.assertEquals("10.00 MB", result.getValue("total_size_mb")); + Assert.assertEquals("0.10 s", result.getValue("total_time_sec")); + } +} + \ No newline at end of file diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/executor/SchemaGenerationTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/executor/SchemaGenerationTest.java index 03c64042c..dc9d18797 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/executor/SchemaGenerationTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/executor/SchemaGenerationTest.java @@ -1,17 +1,17 @@ /* - * Copyright © 2023 Cask Data, Inc. + * Copyright © 2017-2019 Cask Data, Inc. * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License. You may obtain a copy of - * the License at + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. */ package io.cdap.wrangler.executor;